NAR-len tweaks (remasks a small amount of tokens per step, it seems to help with reducing the number of steps needed some of the time?, disable CFG for the first half to speed things up)

This commit is contained in:
mrq 2024-12-04 09:30:29 -06:00
parent cf97560e70
commit 9dff68c0c5
3 changed files with 18 additions and 5 deletions

View File

@ -79,6 +79,9 @@ However, while this solution boasts being lightweight, there are some caveats fo
* `hf`-ifying it is possible, but it'd be a chore to set up the tokenizer properly
* it still seems like the phase of the moon matters with how it wants to cooperate
* some eval tests it seems fine, other times issues like word errors will crop up
* the `NAR-len` requires CFGs > 2-ish to cooperate
* this isn't *so* much of an issue, but this can lead to user error, and CFG incurs an additional sampling step per step.
* guidance distillation would be nice, but distillation in general harms finetuning (assuming this just as likely harms it)
## Notices and Citations

View File

@ -96,6 +96,7 @@ It is ***crucial*** to:
* without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step.
* use unfiltered/unprocessed logit scores:
* not that crucial, but helps stability.
* use a CFG strength of at least 2
It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model).
* additional training is still required to help confidence issues and to condition the model to not fall apart for longer durations.
@ -336,7 +337,7 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte
* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
* `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention).
* training under this is untested, but dropout is not applied (yet).
* `default`: uses the naive path for hte internal implementation (used for attention-debugging purposed)
* `default`: uses the naive path for the internal implementation (used for attention-debugging purposed)
* `transformers` Llama\*Attention implementations:
* `eager`: default `LlamaAttention`
* `sdpa`: integrated `LlamaSdpaAttention` attention model

View File

@ -253,6 +253,7 @@ class AR_NAR(Base):
max_steps = sampling_kwargs.get("max_steps", 25)
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
annealed_sampling = sampling_kwargs.get("annealed_sampling", True)
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
temperature = sampling_kwargs.pop("temperature", 0.0)
@ -261,12 +262,14 @@ class AR_NAR(Base):
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
# force set CFG because too low / no CFG causes issues
cfg_strength = max( cfg_strength, 3.0 )
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
cfg_strength = max( cfg_strength, minimum_cfg_strength )
# if we're denoising from an existing sequence
if start_noise > 0.0 and resps_list is not None:
@ -306,8 +309,10 @@ class AR_NAR(Base):
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# proportion of tokens to remask
remask_p = 1.0 / max_steps if remasking else 0
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# mask off inputs
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
@ -315,8 +320,12 @@ class AR_NAR(Base):
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
sampling_temperature = temperature * annealing
sampling_cfg = cfg_strength * timestep
sampling_temperature = temperature * annealing if annealed_sampling else temperature
sampling_cfg = cfg_strength * timestep if annealed_sampling else temperature
# avoid useless CFG sampling
if sampling_cfg < minimum_cfg_strength * 0.5:
sampling_cfg = 0
# setup inputs
inputs = super().inputs(