From 9dff68c0c57bb46da1847313b0ea23d44bd3050c Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 4 Dec 2024 09:30:29 -0600 Subject: [PATCH] NAR-len tweaks (remasks a small amount of tokens per step, it seems to help with reducing the number of steps needed some of the time?, disable CFG for the first half to speed things up) --- docs/README.md | 3 +++ docs/models.md | 3 ++- vall_e/models/ar_nar.py | 17 +++++++++++++---- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/docs/README.md b/docs/README.md index ab1271e..2b3eb8d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -79,6 +79,9 @@ However, while this solution boasts being lightweight, there are some caveats fo * `hf`-ifying it is possible, but it'd be a chore to set up the tokenizer properly * it still seems like the phase of the moon matters with how it wants to cooperate * some eval tests it seems fine, other times issues like word errors will crop up +* the `NAR-len` requires CFGs > 2-ish to cooperate + * this isn't *so* much of an issue, but this can lead to user error, and CFG incurs an additional sampling step per step. + * guidance distillation would be nice, but distillation in general harms finetuning (assuming this just as likely harms it) ## Notices and Citations diff --git a/docs/models.md b/docs/models.md index 093d565..c11be9e 100644 --- a/docs/models.md +++ b/docs/models.md @@ -96,6 +96,7 @@ It is ***crucial*** to: * without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step. * use unfiltered/unprocessed logit scores: * not that crucial, but helps stability. +* use a CFG strength of at least 2 It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model). * additional training is still required to help confidence issues and to condition the model to not fall apart for longer durations. @@ -336,7 +337,7 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte * `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while * `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention). * training under this is untested, but dropout is not applied (yet). - * `default`: uses the naive path for hte internal implementation (used for attention-debugging purposed) + * `default`: uses the naive path for the internal implementation (used for attention-debugging purposed) * `transformers` Llama\*Attention implementations: * `eager`: default `LlamaAttention` * `sdpa`: integrated `LlamaSdpaAttention` attention model diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 304d39a..a1e2960 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -253,6 +253,7 @@ class AR_NAR(Base): max_steps = sampling_kwargs.get("max_steps", 25) refine_on_stop = sampling_kwargs.get("refine_on_stop", False) entropix_sampling = sampling_kwargs.get("entropix_sampling", False) + annealed_sampling = sampling_kwargs.get("annealed_sampling", True) # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough temperature = sampling_kwargs.pop("temperature", 0.0) @@ -261,12 +262,14 @@ class AR_NAR(Base): cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75) start_noise = sampling_kwargs.get("denoise_start", 0.0) end_noise = sampling_kwargs.get("denoise_end", 1.0) + remasking = sampling_kwargs.get("remasking", True) max_steps = math.floor(max_steps * (end_noise - start_noise)) len_list = [ clamp(l, min_length, max_length) for l in len_list ] # force set CFG because too low / no CFG causes issues - cfg_strength = max( cfg_strength, 3.0 ) + minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0) + cfg_strength = max( cfg_strength, minimum_cfg_strength ) # if we're denoising from an existing sequence if start_noise > 0.0 and resps_list is not None: @@ -306,8 +309,10 @@ class AR_NAR(Base): annealing = 1.0 - timestep # get noise level, per cosine scheduling noise_p = math.cos( timestep * math.pi * 0.5 ) + # proportion of tokens to remask + remask_p = 1.0 / max_steps if remasking else 0 # pick the worst scoring tokens to mask off - masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] + masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] # mask off inputs resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ] # boolean mask @@ -315,8 +320,12 @@ class AR_NAR(Base): # timestep inputs time_list = [ timestep for _ in range(batch_size) ] - sampling_temperature = temperature * annealing - sampling_cfg = cfg_strength * timestep + sampling_temperature = temperature * annealing if annealed_sampling else temperature + sampling_cfg = cfg_strength * timestep if annealed_sampling else temperature + + # avoid useless CFG sampling + if sampling_cfg < minimum_cfg_strength * 0.5: + sampling_cfg = 0 # setup inputs inputs = super().inputs(