From b1f4db39c8dd25377413ccfa27e00b76c3144a95 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 11 Nov 2024 20:27:38 -0600
Subject: [PATCH] threw in CFG sampling for normal model as well to experiment
 with

---
 vall_e/models/ar_nar.py | 48 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 46 insertions(+), 2 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index d9cc325..e2a0f8e 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -323,7 +323,7 @@ class AR_NAR(Base):
 		# keep unmasked tokens
 		resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 		# update scores (conjugated to put the worst scores at the top)
-		scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]
+		scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
 
 		if cfg.experimental and max_steps > 0:
 			print( timestep, steps_until_x0, noise_p, resps_list, scores )
@@ -356,10 +356,12 @@ class AR_NAR(Base):
 
 		batch_size = len(resps_list)
 
-		max_levels = sampling_kwargs.get("max_levels", 0)
 		# convert NAR specific args
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
 
+		max_levels = sampling_kwargs.get("max_levels", 0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+
 		if max_levels == 0:
 			max_levels = self.n_max_levels - 1
 
@@ -395,6 +397,9 @@ class AR_NAR(Base):
 
 		prev_list = resps_list
 
+		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+		null_prom = [ None for _ in range(batch_size) ]
+
 		for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 			level = prev_list[0].shape[-1]
 			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
@@ -421,6 +426,23 @@ class AR_NAR(Base):
 			)
 			logits, state = output.logits, output.state
 
+			if cfg_strength > 0:
+				null_inputs = super().inputs(
+					text_list=null_text,
+					proms_list=null_prom,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+				null_output = super().forward(
+					inputs=null_inputs,
+					quant_levels=quant_levels,
+					#layer_skip_variables=sampling_layer_skip_variables,
+				)
+				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
 			sampled = super().sample(
 				logits=logits,
 				prev_list=prev_list,
@@ -465,6 +487,7 @@ class AR_NAR(Base):
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
 
 		temperature = sampling_kwargs.get("temperature", 1.0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
 		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
 		max_duration = sampling_kwargs.get("max_duration", 500)
 		beam_width = sampling_kwargs.get("beam_width", 0)
@@ -567,6 +590,9 @@ class AR_NAR(Base):
 				sequence_list[i] = sequence_list[i][:, 0]
 			# start_slice[i] = sequence_list[i].shape[0]
 
+		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+		null_prom = [ None for _ in range(batch_size) ]
+
 		# get next in sequence
 		for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
 			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
@@ -591,6 +617,24 @@ class AR_NAR(Base):
 				#layer_skip_variables=sampling_layer_skip_variables,
 				output_attentions=entropix_sampling,
 			)
+
+			if cfg_strength > 0:
+				null_inputs = super().inputs(
+					text_list=null_text,
+					proms_list=null_prom,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+				null_output = super().forward(
+					inputs=null_inputs,
+					quant_levels=quant_levels,
+					#layer_skip_variables=sampling_layer_skip_variables,
+				)
+				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
 			logits, state = output.logits, output.state
 
 			sampled = super().sample(
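
Both the NAR and AR hunks above apply the same classifier-free guidance update: run a second forward pass conditioned on a null text/prompt, then extrapolate from the null logits toward the conditioned logits over the final seq_len positions of each sequence. A minimal standalone sketch of that blend follows; the function name and the usage comment are illustrative, not part of the patch:

import torch

def cfg_blend(cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float) -> torch.Tensor:
	# null + (cond - null) * strength:
	#   strength == 1.0 reproduces the conditioned logits unchanged,
	#   strength  > 1.0 extrapolates away from the unconditioned prediction,
	#   0 < strength < 1 interpolates back toward the null prediction.
	return null_logits + (cond_logits - null_logits) * cfg_strength

# e.g. guide only the last seq_len positions, as the patch does per sequence:
# logit[-seq_len:] = cfg_blend(logit[-seq_len:], null_logit[-seq_len:], cfg_strength)

Note that the patch gates the extra forward pass behind cfg_strength > 0, so the default of 0.0 keeps the original single-pass sampling behavior and cost.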