actually do CFG sampling for base AR+NAR tasks
parent 2495a7ef67
commit b09328069e
@@ -486,7 +486,7 @@ class AR_NAR(Base):
 				null_inputs = super().inputs(
 					text_list=null_text,
 					proms_list=null_prom,
-					resps_list=resps_list,
+					resps_list=prev_list,
 					lang_list=lang_list,
 					tone_list=tone_list,
 					quant_levels=quant_levels,
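The change above matters because CFG compares two forward passes over the same response positions, differing only in conditioning: at the current quant level the conditional pass scores prev_list, so the null pass presumably has to carry prev_list as well rather than the final resps_list. A minimal standalone sketch of that null-batch shape (all names and placeholder values hypothetical, not the project's actual input construction):

import torch

batch_size = 2
# the null pass blanks out the conditioning (text, acoustic prompt)...
null_text = [ torch.zeros( 1, dtype=torch.long ) for _ in range( batch_size ) ]  # placeholder "empty" text
null_prom = [ torch.zeros( 0, dtype=torch.long ) for _ in range( batch_size ) ]  # placeholder empty prompt
# ...but keeps the response tokens identical to the conditional pass
prev_list = [ torch.randint( 0, 1024, ( 8, ) ) for _ in range( batch_size ) ]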
@@ -496,7 +496,8 @@ class AR_NAR(Base):
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)
-				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+				for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
+					seq_len = resp.shape[0]
 					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
 
 				sampled = super().sample(
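For reference, the blend applied here (and again in the AR hunk below) is the standard classifier-free guidance interpolation, guided = null + (cond - null) * strength, applied only over the tail positions belonging to the response. A self-contained sketch under assumed tensor shapes:

import torch

def cfg_blend( logit, null_logit, seq_len, cfg_strength ):
	# guided = null + (cond - null) * strength over the last seq_len positions;
	# strength 1.0 reproduces the conditional logits, larger values push the
	# distribution further away from the unconditioned prediction
	logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
	return logit

cond = torch.randn( 8, 1024 )  # [timesteps, vocab] logits, conditioned
null = torch.randn( 8, 1024 )  # same shape, null-conditioned
guided = cfg_blend( cond.clone(), null, seq_len=8, cfg_strength=3.0 )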
@@ -655,6 +656,7 @@ class AR_NAR(Base):
 			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
 			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
 			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -664,7 +666,7 @@ class AR_NAR(Base):
 				tone_list=tone_list,
 				len_list=len_list,
 				task_list=task_list,
-				quant_levels=[ 0 for _ in range( max( batch_size, beam_width ) ) ]
+				quant_levels=quant_levels,
 			)
 
 			# to-do: find an elegant way to write this
@@ -689,7 +691,8 @@ class AR_NAR(Base):
 				quant_levels=quant_levels,
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
-			for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+			for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
+				seq_len = resp.shape[0] + 1
 				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
 
 			logits, state = output.logits, output.state
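Note the off-by-one between the two paths: the NAR blend above guides resp.shape[0] positions, while this AR blend guides resp.shape[0] + 1, presumably so the logits for the token currently being sampled are guided along with the already-decoded positions.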
@@ -233,6 +233,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
 	parser.add_argument("--refine-on-stop", action="store_true")
 	parser.add_argument("--denoise-start", type=float, default=0.0)
+	parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
 	args, unknown = parser.parse_known_args()
 
 	if is_windows:
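The web UI routes its widget values into an argparse parser, so each new sampler setting needs both a flag (this hunk) and a pass-through at the call site (next hunk). A self-contained sketch of that kwargs-to-flag pattern (values hypothetical):

import argparse

kwargs = { 'cfg-strength': 3.0 }  # hypothetical UI state handed in as kwargs
parser = argparse.ArgumentParser()
parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
# parse_known_args tolerates flags the parser doesn't recognize
args, unknown = parser.parse_known_args( [] )
assert args.cfg_strength == 3.0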
@@ -280,6 +281,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		prefix_silence=args.prefix_silence,
 		input_prompt_prefix=args.input_prompt_prefix,
 		input_prompt_length=args.input_prompt_length,
+		cfg_strength=args.cfg_strength,
 	)
 
 	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t: