From 9e3f2e300f12188995d299a190bd3f19b6bd5590 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 4 Jun 2024 23:23:31 -0500
Subject: [PATCH] experimental "just have a token for what rvq level we're on"
 that seems to help all models (mamba almost works, but it might just have to
 be relegated as a pure AR model)

---
 vall_e/models/ar_nar.py | 10 ++++++++--
 vall_e/models/base.py   | 21 ++++++++++++++++++---
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 3f7c139..e612bf3 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -167,7 +167,9 @@ class AR_NAR(Base):
 				resps_list=resps_list,
 				targ_list=targ_list,
 				lang_list=lang_list,
-				tone_list=tone_list
+				tone_list=tone_list,
+
+				quant_levels=quant_levels,
 			)
 
 			return super().forward(
@@ -193,6 +195,7 @@
 				resps_list=prev_list,
 				lang_list=lang_list,
 				tone_list=tone_list,
+				quant_levels=quant_levels,
 			)
 
 			logits = super().forward(
@@ -336,9 +339,12 @@ def example_usage():
 	device = "cuda"
 
 	# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
+	"""
 	if "mamba" in cfg.model.arch_type:
 		cfg.model.prom_levels = 1
 		cfg.model.resp_levels = 1
+	"""
+	cfg.model.loss_factors = {}
 
 	def tokenize(content):
 		return torch.tensor( cfg.tokenizer.encode(content) )
@@ -390,7 +396,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 250
+	steps = 500
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 8e59982..8575788 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -506,6 +506,7 @@ class Base(nn.Module):
 		self.langs_emb = None
 		self.tones_emb = None
 		self.tasks_emb = None
+		self.rvq_level_emb = None
 
 		if self.version == 1: # legacy
 			n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
@@ -533,6 +534,9 @@
 		if self.version >= 4:
 			self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
 
+		if self.version >= 5:
+			self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
+
 		self.sep = nn.Parameter(torch.randn(d_model))
 
 		# ick, there has to be a better way
@@ -872,14 +876,21 @@ class Base(nn.Module):
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
+
+		quant_levels: Tensor | None = None
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
 		inputs = [ [] for _ in range(batch_size) ]
 
 		for i in range(batch_size):
+			quant_level = quant_levels[i] if quant_levels is not None else 0
+
 			if text_list is not None:
 				inputs[i].append( ( "text", text_list[i] ) )
+
+			inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
+
 			if proms_list is not None:
 				inputs[i].append( ( "prom", proms_list[i] ) )
 			if resps_list is not None:
@@ -897,11 +908,13 @@ class Base(nn.Module):
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
 			batch = []
-			quant_level = quant_levels[batch_index] if quant_levels is not None else None
+			quant_level = quant_levels[batch_index] if quant_levels is not None else 0
 			for name, input in batch_input:
 				embedding = None
 				if name == "text":
 					embedding = self.text_emb( input )
+				elif name == "quant_level":
+					embedding = self.rvq_level_emb( input )
 				elif name == "lang":
 					embedding = self.langs_emb( input )
 				elif name == "prom":
@@ -929,12 +942,13 @@ class Base(nn.Module):
 		# old, "naive" way, no loss factoring
 		if not self.hyper_config.loss_factors:
 			target_list = []
-			for batch in inputs:
+			for batch_index, batch in enumerate(inputs):
 				target = []
+				quant_level = quant_levels[batch_index] if quant_levels is not None else None
 				for name, input in batch:
 					if name == "prom":
 						target.append( torch.full_like(input[..., 0], self.ignore_index) )
-					elif name in ["text", "lang", "tone", "targ"]:
+					elif name in ["text", "quant_level", "lang", "tone", "targ"]:
 						target.append( input )
 
 				target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
@@ -980,6 +994,7 @@ class Base(nn.Module):
 					input = input[:, quant_level]
 
 				seq_len = input.shape[0]
+
 				logit = logits[i][it:it+seq_len]
 				it += seq_len + 1 # +1 to incorporate the separator
 
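
For reference, a minimal standalone sketch of the idea this patch wires in: a dedicated embedding table whose lookup for the current RVQ quantizer level is inserted as one extra token in the model's input sequence. The `rvq_level_emb`, `sep`, `d_model`, and `n_resp_levels` names mirror the diff; the vocab size, text embedding, and sequence layout here are hypothetical stand-ins, not the repo's actual code.

```python
# Illustrative sketch only; not part of the patch.
import torch
import torch.nn as nn

d_model = 1024
n_resp_levels = 8   # number of RVQ levels the model can be asked to predict

text_emb = nn.Embedding(256, d_model)                 # stand-in for the real text embedding
rvq_level_emb = nn.Embedding(n_resp_levels, d_model)  # one learned vector per RVQ level
sep = nn.Parameter(torch.randn(d_model))              # separator between segments, as in base.py

def build_sequence(text_tokens: torch.Tensor, quant_level: int) -> torch.Tensor:
	# embed the text, then insert a single token that tells the model which
	# RVQ level this forward pass targets (the "quant_level" input in the patch)
	level_token = rvq_level_emb(torch.tensor([quant_level]))
	return torch.cat([
		text_emb(text_tokens), sep.unsqueeze(0),
		level_token, sep.unsqueeze(0),
		# ... prompt / response embeddings would follow, as in the patched input-building loop
	], dim=0)

x = build_sequence(torch.randint(0, 256, (32,)), quant_level=2)
print(x.shape)  # torch.Size([35, 1024]): 32 text tokens + sep + level token + sep
```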