actually do CFG sampling for base AR+NAR tasks

This commit is contained in:
mrq 2024-11-12 13:42:39 -06:00
parent 2495a7ef67
commit b09328069e
2 changed files with 9 additions and 4 deletions

View File

@ -486,7 +486,7 @@ class AR_NAR(Base):
null_inputs = super().inputs(
text_list=null_text,
proms_list=null_prom,
resps_list=resps_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
@ -496,7 +496,8 @@ class AR_NAR(Base):
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
seq_len = resp.shape[0]
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
sampled = super().sample(
@ -655,6 +656,7 @@ class AR_NAR(Base):
# it would technically be faster to just append the new token's embedding to the inputs, but the performance gain is VERY small, so it's not worth the added complexity
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
inputs = self.inputs(
text_list=text_list,
@ -664,7 +666,7 @@ class AR_NAR(Base):
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
quant_levels=[ 0 for _ in range( max( batch_size, beam_width ) ) ]
quant_levels=quant_levels,
)
# to-do: find an elegant way to write this
@ -689,7 +691,8 @@ class AR_NAR(Base):
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
seq_len = resp.shape[0] + 1
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
logits, state = output.logits, output.state

View File

@ -233,6 +233,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
parser.add_argument("--refine-on-stop", action="store_true")
parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
args, unknown = parser.parse_known_args()
if is_windows:
@ -280,6 +281,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
prefix_silence=args.prefix_silence,
input_prompt_prefix=args.input_prompt_prefix,
input_prompt_length=args.input_prompt_length,
cfg_strength=args.cfg_strength,
)
with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t: