diff --git a/docs/README.md b/docs/README.md
index ab1271e..2b3eb8d 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -79,6 +79,9 @@ However, while this solution boasts being lightweight, there are some caveats fo
 * `hf`-ifying it is possible, but it'd be a chore to set up the tokenizer properly
 * it still seems like the phase of the moon matters with how it wants to cooperate
 * some eval tests it seems fine, other times issues like word errors will crop up
+* the `NAR-len` modality requires a CFG strength above ~2 to cooperate
+  * this isn't *so* much of an issue, but it can lead to user error, and CFG incurs an additional inference pass per sampling step.
+  * guidance distillation would be nice, but distillation in general harms finetuning (and I assume guidance distillation would harm it just as much)
 
 ## Notices and Citations
 
diff --git a/docs/models.md b/docs/models.md
index 093d565..c11be9e 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -96,6 +96,7 @@ It is ***crucial*** to:
   * without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step.
 * use unfiltered/unprocessed logit scores:
   * not that crucial, but helps stability.
+* use a CFG strength of at least 2
 
 It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model).
 * additional training is still required to help confidence issues and to condition the model to not fall apart for longer durations.
@@ -336,7 +337,7 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte
 * `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
 * `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention).
   * training under this is untested, but dropout is not applied (yet).
-* `default`: uses the naive path for hte internal implementation (used for attention-debugging purposed)
+* `default`: uses the naive path for the internal implementation (used for attention-debugging purposes)
 * `transformers` Llama\*Attention implementations:
   * `eager`: default `LlamaAttention`
   * `sdpa`: integrated `LlamaSdpaAttention` attention model
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 304d39a..a1e2960 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -253,6 +253,7 @@ class AR_NAR(Base):
 		max_steps = sampling_kwargs.get("max_steps", 25)
 		refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
 		entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
+		annealed_sampling = sampling_kwargs.get("annealed_sampling", True)
 
 		# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
 		temperature = sampling_kwargs.pop("temperature", 0.0)
@@ -261,12 +262,14 @@ class AR_NAR(Base):
 		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
+		remasking = sampling_kwargs.get("remasking", True)
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
 
 		len_list = [ clamp(l, min_length, max_length) for l in len_list ]
 
 		# force set CFG because too low / no CFG causes issues
-		cfg_strength = max( cfg_strength, 3.0 )
+		minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
+		cfg_strength = max( cfg_strength, minimum_cfg_strength )
 
 		# if we're denoising from an existing sequence
 		if start_noise > 0.0 and resps_list is not None:
@@ -306,8 +309,10 @@ class AR_NAR(Base):
 			annealing = 1.0 - timestep
 			# get noise level, per cosine scheduling
 			noise_p = math.cos( timestep * math.pi * 0.5 )
+			# proportion of tokens to remask each step
+			remask_p = 1.0 / max_steps if remasking else 0
 			# pick the worst scoring tokens to mask off
-			masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
+			masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
 			# mask off inputs
 			resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
 			# boolean mask
@@ -315,8 +320,12 @@ class AR_NAR(Base):
 			# timestep inputs
 			time_list = [ timestep for _ in range(batch_size) ]
 
-			sampling_temperature = temperature * annealing
-			sampling_cfg = cfg_strength * timestep
+			sampling_temperature = temperature * annealing if annealed_sampling else temperature
+			sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength
+
+			# avoid a useless CFG sampling pass
+			if sampling_cfg < minimum_cfg_strength * 0.5:
+				sampling_cfg = 0
 
 			# setup inputs
 			inputs = super().inputs(