diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index fd27e99..d8ed04e 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -486,7 +486,7 @@ class AR_NAR(Base):
 			null_inputs = super().inputs(
 				text_list=null_text,
 				proms_list=null_prom,
-				resps_list=resps_list,
+				resps_list=prev_list,
 				lang_list=lang_list,
 				tone_list=tone_list,
 				quant_levels=quant_levels,
@@ -496,7 +496,8 @@ class AR_NAR(Base):
 				quant_levels=quant_levels,
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
-			for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+			for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
+				seq_len = resp.shape[0]
 				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
 
 		sampled = super().sample(
@@ -655,6 +656,7 @@ class AR_NAR(Base):
 			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
 			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
 			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -664,7 +666,7 @@ class AR_NAR(Base):
 				tone_list=tone_list,
 				len_list=len_list,
 				task_list=task_list,
-				quant_levels=[ 0 for _ in range( max( batch_size, beam_width ) ) ]
+				quant_levels=quant_levels,
 			)
 
 			# to-do: find an elegant way to write this
@@ -689,7 +691,8 @@ class AR_NAR(Base):
 				quant_levels=quant_levels,
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
-			for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+			for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
+				seq_len = resp.shape[0] + 1
 				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
 
 		logits, state = output.logits, output.state
diff --git a/vall_e/webui.py b/vall_e/webui.py
index f80ca5e..4c2b2db 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -233,6 +233,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
 	parser.add_argument("--refine-on-stop", action="store_true")
 	parser.add_argument("--denoise-start", type=float, default=0.0)
+	parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
 	args, unknown = parser.parse_known_args()
 
 	if is_windows:
@@ -280,6 +281,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		prefix_silence=args.prefix_silence,
 		input_prompt_prefix=args.input_prompt_prefix,
 		input_prompt_length=args.input_prompt_length,
+		cfg_strength=args.cfg_strength,
 	)
 
 	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
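
For reference outside the patch: both edited loops apply per-item classifier-free guidance over the trailing response positions of each sequence, and the diff's key fix is deriving seq_len from each response tensor rather than from len_list, so the guided span matches the actual per-item sequence length. A minimal sketch of that blend, assuming logits/null_logits are lists of (seq_len, vocab) tensors and resps_list holds the response token tensors; apply_cfg and offset are illustrative names, not part of the repo's API:

	import torch

	def apply_cfg(logits, null_logits, resps_list, cfg_strength, offset=0):
		# Guide only the last seq_len positions (the response region) of each item:
		#   guided = null + (cond - null) * cfg_strength
		# offset=1 mirrors the AR hunk, which also guides the next-token slot.
		for resp, logit, null_logit in zip(resps_list, logits, null_logits):
			seq_len = resp.shape[0] + offset
			logit[-seq_len:] = null_logit[-seq_len:] + (logit[-seq_len:] - null_logit[-seq_len:]) * cfg_strength
		return logits

	# Toy usage with hypothetical shapes (10 positions, 256-way vocab, 6 response tokens):
	cond = [torch.randn(10, 256)]
	uncond = [torch.randn(10, 256)]
	resps = [torch.zeros(6, dtype=torch.long)]
	guided = apply_cfg(cond, uncond, resps, cfg_strength=3.0)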