diff --git a/docs/models.md b/docs/models.md
index 44c3500..5f6cab7 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -35,15 +35,19 @@ Non-autoregressive training is performed by having the input tokens from the prev
 However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
 * The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
-* The latter however proves to be a bit of a challenge, as this could be anything from random noise to a unique token.
-  * The current implementation repeats the input prompt's RVQ level 0 as the initial condition, but inferencing fills with stop tokens. This *might* be the problem, but I do not have my `nar-len-llama-8` weights stored anywhere, sadly.
-* Testing showed that it's easy to predict the duration, but decoding the first RVQ level accurately proves to be a chore.
-  * Initially, output seemed chaotic and unreliable, but further experiments showed the model will "work" for a brief moment before going silent.
+* The latter, however, proves challenging, as generating tokens from nothing in a single step is not possible.
+  * Diffusion solves this, but at best requires additional steps and at worst a separate model, just for one RVQ level.
+  * However, a paradigm similar to diffusers is possible: instead of iterating on random noise, masked tokens are iterated upon, with each step keeping only its most confident predictions (a minimal sketch follows this diff).
+  * Incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this approach with a NAR transformer for image generation.
+  * The normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
 
 One problem exhibited from a NAR is producing artifacts ("crust") in the final waveform. I believe this is a confidence problem where the wrong token is inferred.
 * Unfortunately, one solution is to simply train a separate NAR, as this should help bolster the model's NAR capabilities without the AR influencing things, as I imagine being able to decode tokens both causally and in parallel harms things.
   * This is backed by how the `cfg.model.experimental.rvq_levels_p` distribution affects the model's AR capabilities, as increasing the NAR's share in training causes the AR to perform *worse*.
   * However, this may simply be wrong, but checkpoints that used such distributions felt lobotomized.
+* Another solution that may help is to provide two token dropout methods (sketched after this diff):
+  * `token_dropout_error`: randomly nudges a small percentage of tokens from the prior RVQ level to simulate wrong tokens being predicted.
+  * `token_dropout_rate`: randomly masks off tokens from the prior RVQ level with a mask token, so the model does not rely too strongly on the given input.
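To make the masked-token paradigm concrete, here is a minimal sketch of the confidence-based iterative decoding loop the doc change describes. Everything here is illustrative: `model`, `masked_decode`, the linear unmasking schedule, and the tensor shapes are assumptions, not the repo's actual API.

```python
# A minimal sketch of confidence-based iterative decoding over masked tokens
# (the MaskGIT-style paradigm described above). `model` and the linear
# unmasking schedule are stand-in assumptions, not the actual vall_e API.
import torch

def masked_decode(model, seq_len: int, mask_token: int, steps: int = 8) -> torch.Tensor:
	# start from a fully masked sequence instead of random noise
	tokens = torch.full((seq_len,), mask_token, dtype=torch.long)

	for step in range(steps):
		logits = model(tokens)              # assumed shape: (seq_len, vocab)
		probs = logits.softmax(dim=-1)
		conf, pred = probs.max(dim=-1)      # per-token confidence and argmax

		# unmask progressively: many tokens stay masked early, none at the end
		n_masked = int(seq_len * (1.0 - (step + 1) / steps))

		tokens = pred.clone()
		if n_masked > 0:
			# re-mask the least confident positions for the next iteration
			_, low_conf = conf.topk(n_masked, largest=False)
			tokens[low_conf] = mask_token

	return tokens
```

Each pass commits only its most confident predictions and re-masks the rest, so the RVQ level 0 sequence is produced in a fixed number of steps rather than one token at a time. Likewise, a rough sketch of the two proposed dropout methods applied to the prior RVQ level during training; `apply_token_dropout` and the ±1 "nudge" are hypothetical, only the knob names `token_dropout_error` and `token_dropout_rate` come from the notes above.

```python
# A rough sketch of the two token dropout methods applied to the prior RVQ
# level's tokens during training. The helper and the ±1 "nudge" are assumptions.
import random
import torch

def apply_token_dropout(
	tokens: torch.Tensor,      # (seq_len,) token IDs from the prior RVQ level
	n_audio_tokens: int,       # codebook size
	mask_token: int,           # ID of the special mask token
	token_dropout_error: float = 0.01,
	token_dropout_rate: float = 0.05,
) -> torch.Tensor:
	tokens = tokens.clone()
	for i in range(tokens.shape[0]):
		r = random.random()
		if r < token_dropout_error:
			# simulate a wrongly-predicted token by nudging it to a neighbor
			tokens[i] = (tokens[i] + random.choice((-1, 1))).clamp(0, n_audio_tokens - 1)
		elif r < token_dropout_error + token_dropout_rate:
			# mask the token so the model can't strongly rely on this input
			tokens[i] = mask_token
	return tokens
```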
 
 ## Embeddings
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 39a9965..954907b 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -241,8 +241,8 @@ class AudioEmbedding(nn.Module):
 		if self.capabilities is None:
 			offset = 0
 		# resp
-		elif "len" in self.capabilities:
-			offset = 1
+		#elif "len" in self.capabilities:
+		#	offset = 1
 		elif "nar" not in self.capabilities:
 			offset = 0
 		elif quant_level > 0:
@@ -460,21 +460,15 @@ class Base(nn.Module):
 		if "nar" not in self.capabilities:
 			n_resp_tokens = n_audio_tokens + 1
 			l_tokens = [n_resp_tokens] * self.n_resp_levels
-		# NAR-len model
+		# AR+NAR model
 		elif "len" not in self.capabilities:
 			# +1 to include the stop token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
 			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-		# AR+NAR model
+		# NAR-len model
 		else:
 			n_resp_tokens = n_audio_tokens
-			l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0))
-
-		# there seems to be a problem with the NAR-only model with non-unified position IDs.............
-		"""
-		if "len" in self.capabilities and not unified_position_ids:
-			raise Exception("ERROR: model instability for NAR-only model when not using unified position IDs.")
-		"""
+			l_tokens = [n_resp_tokens] * (self.n_resp_levels + 1)
 
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
@@ -490,7 +484,8 @@
 
 		# it would be nicer for these to be a token or live inside an embedding
 		self.sep = nn.Parameter(torch.randn(d_model))
-		self.dropout_token = nn.Parameter(torch.zeros(d_model)) # zeros sounds nicer than randn for a special value
+		self.dropout_token = nn.Parameter(torch.randn(d_model))
+		self.mask_token = self.dropout_token # alias to the above (same parameter, two names)
 
 		if self.version == 1: # legacy
 			n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
@@ -521,7 +516,6 @@
 			capabilities=self.capabilities,
 		)
 
-		# useless since I actually removed using these with the input processing overhaul...
 		if self.version >= 3:
 			self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
 			self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
@@ -533,6 +527,7 @@
 		# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
 		# this ***might*** let me also unify the proms_emb and resps_embedding
 		if self.version >= 5:
+			# "len" models reserve an additional RVQ-level token for the level-0 (len) task
 			self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model)
 
 		# experimental NAR-only mode
@@ -555,6 +550,7 @@
 		if attention_backend not in AVAILABLE_ATTENTIONS:
 			raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
 
+		# override any requested padding size to what the attention backend requires
 		if attention_backend == "flash_attn_v100":
 			self.l_padding = 32
 		elif attention_backend == "fused_attn":
@@ -1126,8 +1122,9 @@ class Base(nn.Module):
 				embedding = _interleave_sequence_reshape( embeddings )
 			elif "len" in self.capabilities and quant_level == 0:
-				assert input_prom is not None, "Guru mediating during training"
 				# fill with the prom as the initial condition
+				"""
+				assert input_prom is not None, "Guru meditation"
 				repeat = (input.shape[0] // input_prom.shape[0]) + 1
 				repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
@@ -1137,7 +1134,29 @@
 					quant_level = 0,
 				)
 				"""
-				# fill with "stop" token from the len layer for the NAR-only model
+
+
+				# if training (the input is still a discrete token sequence)
+				if not input.is_floating_point():
+					# get original sequence
+					embedding = self.resps_emb(
+						input,
+						offset = 0,
+						quant_level = 0,
+					)
+					# randomly replace timesteps with the mask token
+					for i in range( embedding.shape[0] ):
+						# keep each token with 20% probability, masking the rest (schedule per the paper referenced in docs/models.md)
+						if random.random() > 0.8:
+							continue
+						embedding[i] = self.dropout_token
+				# if inferencing (the caller already supplies float mask embeddings)
+				else:
+					# fill with mask tokens
+					embedding = torch.concat([ self.dropout_token.unsqueeze(0) for _ in range( input.shape[0] ) ])
+
+				"""
+				# fill with filler token from the len layer for the NAR-only model
 				filler_token = 12
 				embedding = self.resps_emb(
 					# self.dropout_token.repeat((input.shape[0], 1)),
@@ -1165,9 +1184,11 @@
 				)
 			else:
 				offset = 0
+				"""
 				if "len" in self.capabilities:
 					offset = 1
-				elif "nar" not in self.capabilities:
+				"""
+				if "nar" not in self.capabilities:
 					offset = 0
 				elif quant_level > 0:
 					offset = 1
@@ -1676,8 +1697,10 @@
 		if quant_levels is not None and "ar" in self.capabilities:
 			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
 		# (AR-len) disable extraneous tokens
+		"""
 		if quant_levels is None and "len" in self.capabilities:
 			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
+		"""
 
 		# perform repetition penalizing
 		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 6e5822c..e6382b7 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -216,9 +216,11 @@ class NAR(Base):
 			max_levels = self.n_resp_levels
 
 		# fill with mock tokens
-		# to-do: repeat with the input prompt, as per training
 		#prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
-		prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
+		#prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
+
+		prev_list = [ torch.concat([ self.dropout_token.unsqueeze(0) for _ in range( resp_len ) ]) for resp_len in len_list ]
+		#prev_list = [ None for resp_len in len_list ]
 
 		# to-do: figure out why this fails when I copy some things from ar_nar
 		for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
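As a usage note on the branch above: the code distinguishes training from inference purely by dtype, since training passes discrete token IDs while `nar.py` pre-fills the sequence with the floating-point `dropout_token` embedding. Below is a standalone sketch of that dispatch; the sizes, the helper name `embed_level0`, and the flat 80% mask probability are hypothetical stand-ins.

```python
# Standalone sketch of the dtype-based training/inference dispatch used for
# the NAR-len RVQ level 0. Sizes and the helper name are assumptions.
import torch
import torch.nn as nn

d_model, n_audio_tokens, seq_len = 1024, 1025, 8

resps_emb = nn.Embedding(n_audio_tokens, d_model)
dropout_token = nn.Parameter(torch.randn(d_model))  # the learned mask embedding

def embed_level0(input: torch.Tensor, p_mask: float = 0.8) -> torch.Tensor:
	if not input.is_floating_point():
		# training: embed the ground-truth tokens, then mask most timesteps
		embedding = resps_emb(input)
		masked = torch.rand(embedding.shape[0]) < p_mask
		embedding[masked] = dropout_token
		return embedding
	# inference: the caller (as in nar.py) already filled the sequence with
	# the mask embedding, so it passes through untouched
	return input

train_emb = embed_level0(torch.randint(0, n_audio_tokens, (seq_len,)))   # int64 -> training path
infer_emb = embed_level0(dropout_token.unsqueeze(0).repeat(seq_len, 1))  # float -> inference path
```

This mirrors the diff: `embedding[i] = self.dropout_token` replaces masked timesteps during training, while `nar.py` builds the initial sequence from the same embedding via `torch.concat` at inference.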