NAR-len tweaks (remasks a small amount of tokens per step, it seems to help with reducing the number of steps needed some of the time?, disable CFG for the first half to speed things up)
This commit is contained in:
parent
cf97560e70
commit
9dff68c0c5
|
@ -79,6 +79,9 @@ However, while this solution boasts being lightweight, there are some caveats fo
|
|||
* `hf`-ifying it is possible, but it'd be a chore to set up the tokenizer properly
|
||||
* it still seems like the phase of the moon matters with how it wants to cooperate
|
||||
* some eval tests it seems fine, other times issues like word errors will crop up
|
||||
* the `NAR-len` requires CFGs > 2-ish to cooperate
|
||||
* this isn't *so* much of an issue, but this can lead to user error, and CFG incurs an additional sampling step per step.
|
||||
* guidance distillation would be nice, but distillation in general harms finetuning (assuming this just as likely harms it)
|
||||
|
||||
|
||||
## Notices and Citations
|
||||
|
|
|
@ -96,6 +96,7 @@ It is ***crucial*** to:
|
|||
* without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step.
|
||||
* use unfiltered/unprocessed logit scores:
|
||||
* not that crucial, but helps stability.
|
||||
* use a CFG strength of at least 2
|
||||
|
||||
It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model).
|
||||
* additional training is still required to help confidence issues and to condition the model to not fall apart for longer durations.
|
||||
|
@ -336,7 +337,7 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte
|
|||
* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
|
||||
* `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention).
|
||||
* training under this is untested, but dropout is not applied (yet).
|
||||
* `default`: uses the naive path for hte internal implementation (used for attention-debugging purposed)
|
||||
* `default`: uses the naive path for the internal implementation (used for attention-debugging purposed)
|
||||
* `transformers` Llama\*Attention implementations:
|
||||
* `eager`: default `LlamaAttention`
|
||||
* `sdpa`: integrated `LlamaSdpaAttention` attention model
|
||||
|
|
|
@ -253,6 +253,7 @@ class AR_NAR(Base):
|
|||
max_steps = sampling_kwargs.get("max_steps", 25)
|
||||
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
|
||||
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
|
||||
annealed_sampling = sampling_kwargs.get("annealed_sampling", True)
|
||||
|
||||
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
|
||||
temperature = sampling_kwargs.pop("temperature", 0.0)
|
||||
|
@ -261,12 +262,14 @@ class AR_NAR(Base):
|
|||
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
|
||||
start_noise = sampling_kwargs.get("denoise_start", 0.0)
|
||||
end_noise = sampling_kwargs.get("denoise_end", 1.0)
|
||||
remasking = sampling_kwargs.get("remasking", True)
|
||||
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
||||
|
||||
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
||||
|
||||
# force set CFG because too low / no CFG causes issues
|
||||
cfg_strength = max( cfg_strength, 3.0 )
|
||||
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
|
||||
cfg_strength = max( cfg_strength, minimum_cfg_strength )
|
||||
|
||||
# if we're denoising from an existing sequence
|
||||
if start_noise > 0.0 and resps_list is not None:
|
||||
|
@ -306,8 +309,10 @@ class AR_NAR(Base):
|
|||
annealing = 1.0 - timestep
|
||||
# get noise level, per cosine scheduling
|
||||
noise_p = math.cos( timestep * math.pi * 0.5 )
|
||||
# proportion of tokens to remask
|
||||
remask_p = 1.0 / max_steps if remasking else 0
|
||||
# pick the worst scoring tokens to mask off
|
||||
masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||
# mask off inputs
|
||||
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||
# boolean mask
|
||||
|
@ -315,8 +320,12 @@ class AR_NAR(Base):
|
|||
# timestep inputs
|
||||
time_list = [ timestep for _ in range(batch_size) ]
|
||||
|
||||
sampling_temperature = temperature * annealing
|
||||
sampling_cfg = cfg_strength * timestep
|
||||
sampling_temperature = temperature * annealing if annealed_sampling else temperature
|
||||
sampling_cfg = cfg_strength * timestep if annealed_sampling else temperature
|
||||
|
||||
# avoid useless CFG sampling
|
||||
if sampling_cfg < minimum_cfg_strength * 0.5:
|
||||
sampling_cfg = 0
|
||||
|
||||
# setup inputs
|
||||
inputs = super().inputs(
|
||||
|
|
Loading…
Reference in New Issue
Block a user