From 48cd1054f97e4608408c7e0a6035e545c1cbba29 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 4 Jun 2024 23:48:51 -0500
Subject: [PATCH] madness

---
 vall_e/config.py              |  6 ++++--
 vall_e/models/ar_nar.py       | 18 ++++++++++++------
 vall_e/models/base.py         | 14 ++++++++++----
 vall_e/models/experimental.py |  2 +-
 vall_e/train.py               |  2 +-
 5 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 5d835c4..4ae3f64 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -214,9 +214,11 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
-	kv_heads: int = 0
+	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
+	loss_factors: dict = field(default_factory=lambda: {})
+	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: bool = False # for now it sets things to be HF compatible
+	kv_heads: int = 0
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index e612bf3..ac3ad15 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -15,6 +15,8 @@ from ..emb.qnt import trim
 class AR_NAR(Base):
 	@property
 	def causal(self):
+		if hasattr(self, "config") and self.config:
+			return "ar" in self.capabilities
 		return True
 
 	@property
@@ -135,9 +137,9 @@ class AR_NAR(Base):
 					index = i
 				return int(index)
 
-			quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
+			quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
 		else:
-			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
 		"""
 		if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
 			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
@@ -344,7 +346,7 @@ def example_usage():
 	cfg.model.prom_levels = 1
 	cfg.model.resp_levels = 1
 	"""
-	cfg.model.loss_factors = {}
+	# cfg.model.loss_factors = {}
 
 	def tokenize(content):
 		return torch.tensor( cfg.tokenizer.encode(content) )
@@ -396,7 +398,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 500
+	steps = 200
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -468,7 +470,11 @@ def example_usage():
 			return
 
 	engine.eval()
-	resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+	if "ar" in cfg.model.capabilities:
+		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+	else:
+		resps_list = [ qnt[:, 0].to( device ) ]
+
 	if cfg.model.max_levels > 1:
 		resps_list = [r.unsqueeze(-1) for r in resps_list]
 		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
@@ -492,7 +498,7 @@ def example_usage():
 			'module': model.state_dict()
 		}, f"./data/{cfg.model.arch_type}.pth" )
 
-	sample("init", 5)
+	#sample("init", 5)
 
 	train()
 	sample("final")
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 8575788..278a9e3 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -498,7 +498,6 @@ class Base(nn.Module):
 		self.l_padding = l_padding
 
 		# +1 to include the stop token
-		# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
 		n_prom_tokens = n_tokens
 		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
 
@@ -1009,11 +1008,14 @@ class Base(nn.Module):
 					"logits": [],
 				}
 
-				info[name]["targets"].append( input.contiguous() )
-				info[name]["logits"].append( logit.contiguous() )
+				info[name]["targets"].append( input ) # input.contiguous()
+				info[name]["logits"].append( logit ) # logit.contiguous()
 
 		for name, batch in info.items():
 			loss_factor = self.loss_factor(name)
+			if name not in ["text", "prom", "resp"]:
+				continue
+
 			if loss_factor == 0.0:
 				continue
 
@@ -1021,7 +1023,11 @@ class Base(nn.Module):
 			inputs = torch.cat( batch["logits"] )
 
 			self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
-			self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+			try:
+				self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+			except Exception as e:
+				print( name, inputs.shape, targets.shape, e )
+				pass
 
 		# to-do: compute loss per individual batch to scale per RVQ level
 		"""
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index 052c3d9..471236c 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -434,7 +434,7 @@ def example_usage():
 		stats = {"step": i}
 
 		batch_size = len(text_list)
-		quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
+		quant_levels = None if cfg.model.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
 		if quant_levels is not None:
 			resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
 		else:
diff --git a/vall_e/train.py b/vall_e/train.py
index 7909348..ba2e3e1 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -32,7 +32,7 @@ def train_feeder(engine, batch):
 		quant_levels = None
 		resps_list = [ resp for resp in batch["resps"] ]
 	else:
-		quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
+		quant_levels = torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
 		resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
 
 		input_ids, attention_mask = fold_inputs(