From 0e621354e7fa809972cf25e37a51343ed14f4f23 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 19 Nov 2024 10:30:05 -0600
Subject: [PATCH] cleaned up classifier-free guidance logit processing (to try
 and cope with a bad nar-len model)

---
 vall_e/models/ar_nar.py | 96 ++++++++++++++++-------------------------
 vall_e/samplers.py      | 17 ++++++++
 2 files changed, 55 insertions(+), 58 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 4aedf8f..55a1cc4 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -27,6 +27,7 @@
 from ..emb.qnt import trim, encode_as_embedding, get_silence
 from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
 from .lora import enable_lora
+from ..samplers import cfg_logits
 
 text_task = [ "stt" ]
 
@@ -223,8 +224,8 @@ class AR_NAR(Base):
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
 
-		"""
 		# to-do: check if gumbel sampling works / helps
+		"""
 		def log(x, eps = 1e-20):
 			return torch.log(x.clamp(min = eps))
 
@@ -232,6 +233,18 @@ class AR_NAR(Base):
 			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
 		"""
 
+		def log(t, eps=1e-10):
+			return torch.log(t + eps)
+
+
+		def gumbel_noise(t):
+			noise = torch.zeros_like(t).uniform_(0, 1)
+			return -log(-log(noise))
+
+
+		def gumbel_sample(t, temperature=1.0, dim=-1):
+			return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
+
 		# convert (N)AR specific args
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
 
@@ -243,6 +256,7 @@ class AR_NAR(Base):
 		temperature = sampling_kwargs.pop("temperature", 1.0)
 		cfg_strength = sampling_kwargs.get("cfg_strength", 3.0) # this really helps keep audio coherent so far
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
@@ -286,6 +300,7 @@ class AR_NAR(Base):
 			annealing = 1.0 - timestep
 			# get noise level, per cosine scheduling
 			noise_p = math.cos( timestep * math.pi * 0.5 )
+			#noise_p = annealing
 			# pick the worst scoring tokens to mask off
 			masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
 			# mask off inputs
@@ -335,8 +350,8 @@ class AR_NAR(Base):
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)
-				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * sampling_cfg
+
+				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
 
 			# sample with sampler settings
 			filtered_sampled = super().sample(
@@ -361,60 +376,26 @@ class AR_NAR(Base):
 			"""
 			# update previous list of tokens
 			prev_list = resps_list
-
-			# sample with gumbel noise
-			# This actually lobotomizes things
-			#sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits[0] ]
+			# get sampled tokens
 			sampled_ids = filtered_sampled.ids
+			# keep unmasked tokens
+			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
+			# get probability scores (complemented, to have worse scoring tokens picked for topk)
+			scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
+
+			"""
+			# maskgct does some funny stuff but it doesn't amount to anything
+			if annealing < 1.0e-3:
+				sampled_ids = filtered_sampled.ids
+			else:
+				sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ]
 
 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 			# update scores (conjugated to put the worst scores at the top)
-			scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
-
-			# refinement step
-			if refine_on_stop:
-				inputs = super().inputs(
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-					lang_list=lang_list,
-					tone_list=tone_list,
-					quant_levels=quant_levels,
-				)
-				output = super().forward(
-					inputs=inputs,
-					quant_levels=quant_levels,
-					#layer_skip_variables=sampling_layer_skip_variables,
-				)
-
-				logits = output.logits
-
-				if cfg_strength > 0:
-					null_inputs = super().inputs(
-						text_list=null_text,
-						proms_list=null_prom,
-						resps_list=resps_list,
-						lang_list=lang_list,
-						tone_list=tone_list,
-						quant_levels=quant_levels,
-					)
-					null_output = super().forward(
-						inputs=null_inputs,
-						quant_levels=quant_levels,
-						#layer_skip_variables=sampling_layer_skip_variables,
-					)
-					for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-						logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
-
-				logits = [ logit[-length-1:-1] for logit, length in zip(logits, len_list) ]
-				# greedy sample from the sequence
-				refined_list = [ logit.argmax(dim=-1) for logit in logits ]
-
-				"""
-				if cfg.experimental and max_steps > 0:
-					print( timestep, steps_until_x0, noise_p, resps_list, scores )
-				"""
+			scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
+			scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ]
+			"""
 
 		return resps_list
 
@@ -449,6 +430,7 @@ class AR_NAR(Base):
 
 		max_levels = sampling_kwargs.get("max_levels", 0)
 		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
 
 		if max_levels == 0:
 			max_levels = self.n_max_levels - 1
@@ -541,9 +523,8 @@ class AR_NAR(Base):
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)
-				for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
-					seq_len = resp.shape[0]
-					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
+				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] for resp in resps_list ] )
 
 			sampled = super().sample(
@@ -591,6 +572,7 @@ class AR_NAR(Base):
 
 		temperature = sampling_kwargs.get("temperature", 1.0)
 		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
 		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
 		max_duration = sampling_kwargs.get("max_duration", 500)
 		beam_width = sampling_kwargs.get("beam_width", 0)
@@ -736,9 +718,7 @@ class AR_NAR(Base):
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)
-				for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
-					seq_len = resp.shape[0] + 1
-					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
 
 			logits, state = output.logits, output.state
 
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 8d8db5a..77afe8d 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -171,7 +171,24 @@ def top_no_logits_processing( logits, n = 1.0 ):
 	return logits
 
+# perform classifier-free guidance given positive logits and negative/null logits
+# some funny nonsense: this needs to operate on slices, since it's performed before sampling and the logits still cover the entire sequence
+# (and because the null logits have a shorter input sequence compared to the positive logits)
+def cfg_logits( logits, null, strength, lens, rescale=0.0 ):
+	for i, seq_len in enumerate( lens ):
+		pos = logits[i][-seq_len:]
+		neg = null[i][-seq_len:]
+		summed = neg + (pos - neg) * strength
+
+		if rescale <= 0:
+			logits[i][-seq_len:] = summed
+		else:
+			dims = tuple(range(1, summed.ndim - 1))
+			factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale)
+			logits[i][-seq_len:] = summed * factor
+
+	return logits
+
 # Credit to: https://github.com/basusourya/mirostat/
 # performs mirostat-based sampling
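
For context (not part of the patch): cfg_logits implements vanilla classifier-free guidance, summed = neg + (pos - neg) * strength, over only the trailing seq_len positions of each sequence, since the null pass conditions on less input and only the tails of the two logit lists line up. The optional rescale then pulls the standard deviation of the guided logits back toward that of the positive pass, along the lines of the CFG-rescale trick from the diffusion literature, which counteracts the variance blow-up that strong guidance causes. Below is a minimal, self-contained sketch of the same arithmetic on a single slice; cfg_rescale_sketch is a hypothetical name, not part of the repo, and the std here is taken per position over the vocab dimension for illustration, whereas the patched helper derives its reduction dims from the tensor rank.

import torch

def cfg_rescale_sketch( pos, neg, strength=3.0, rescale=0.7 ):
	# vanilla classifier-free guidance: extrapolate away from the null logits
	summed = neg + (pos - neg) * strength
	if rescale <= 0:
		return summed
	# restore (a fraction of) the positive pass's spread, since strong
	# guidance inflates the variance of the guided logits
	factor = rescale * (pos.std(dim=-1, keepdim=True) / summed.std(dim=-1, keepdim=True)) + (1.0 - rescale)
	return summed * factor

pos = torch.randn( 75, 1025 )	# stand-in conditioned logits for one sequence's tail
neg = torch.randn( 75, 1025 )	# stand-in null/unconditioned logits for the same tail
guided = cfg_rescale_sketch( pos, neg )
print( guided.shape )	# torch.Size([75, 1025])

With rescale=0 this reduces to exactly the expression the removed in-line loops computed; the 0.7 used above matches the cfg_rescale default the patch threads through the NAR-len, NAR, and AR paths.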