new implementation tweaks and fixes to make it actually better (there were a lot of badwrong things being done that harmed the output quality, will evaluate the model further)

mrq 2025-04-18 20:36:44 -05:00
parent 98d1d8cb1e
commit d9e18037cc
4 changed files with 81 additions and 31 deletions

View File

@@ -152,6 +152,14 @@ These settings should be avoided:
* however, it seems there's a regression that caused this to stop working consistently
* disabling this falls back to explicitly training a `len` task (like the old implementation)
## Samplers
To-do: Remember what I was going to jot down here
Sampling code is effectively the same, with the twist that the logits for all codebooks are now emitted along `dim=0`.
The NAR-demasking step accounts for this automatically, and exposes knobs to control whether tokens are masked off independently per codebook level, or for all codebook levels at a given timestep.
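A rough sketch of the two behaviors (shapes and names here are illustrative, not the actual implementation): masking independently lets each codebook level re-mask its own least-confident positions, while masking per-timestep averages the scores across levels so every level at a timestep is re-masked or kept together.

```python
# illustrative only: per-level vs. per-timestep demasking selection
import torch

levels, seq_len, n_mask = 8, 6, 2
scores = torch.rand(levels, seq_len)  # per-token confidence, higher = more confident

# independent: each codebook level re-masks its own least-confident positions
per_level = scores.topk(n_mask, dim=-1, largest=False).indices          # (levels, n_mask)

# per-timestep: average across levels so all levels share the same re-masked positions
per_timestep = scores.mean(dim=0).topk(n_mask, largest=False).indices   # (n_mask,)
```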
## Benefits and Caveats
To be evaluated thoroughly.

View File

@@ -296,9 +296,14 @@ class AR_NAR(Base):
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))
largest_score = 1.0
smallest_score = 0.0 # -float("inf")
score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
remasking = sampling_kwargs.get("sampling_scores_remask", False)
# to specify the initial mask used
vc_list = sampling_kwargs.pop("vc_list", None)
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
@@ -326,7 +331,7 @@ class AR_NAR(Base):
# gen masking ratio
noise_p = math.cos( start_noise * math.pi * 0.5 )
# generate scoring mask (because the above mask will get masked off per the scores, so we do not need to mask beforehand)
scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
scores = [ torch.tensor( [ smallest_score if random.random() < noise_p else largest_score for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
else:
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
@@ -350,7 +355,7 @@ class AR_NAR(Base):
remask_p = 1.0 / (max_steps * 2) if remasking else 0
mask_p = noise_p + remask_p
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=-1, largest=False ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
if vc_list is None or timestep >= vc_threshold:
@@ -450,15 +455,11 @@ class AR_NAR(Base):
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# get probability scores
scores = [
# conjugate to have worse scoring tokens picked for topk
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
]
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
if score_masked_only:
scores = [ torch.where( masked, scores.t(), largest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
else:
scores = [ scores for scores in sampled.scores ]
return resps_list
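The gist of the scoring changes above, as a minimal standalone sketch (names, shapes, and the kept-token handling are simplified and illustrative, not the verbatim code): scores now follow "lower = less confident = re-mask", so the worst-scoring positions are selected with `largest=False`, and with `sampling_scores_masked_only` only the freshly predicted positions take their new confidence while already-kept positions are pinned to the best score so they are never reselected.

# illustrative sketch of one NAR-len demasking step's score bookkeeping (not the verbatim code)
import math
import torch

def demask_step(scores, confidences, step, max_steps, remasking=False):
    seq_len = scores.shape[0]
    # cosine schedule: the fraction of positions to re-mask shrinks as steps progress
    noise_p = math.cos((step / max_steps) * math.pi * 0.5)
    remask_p = 1.0 / (max_steps * 2) if remasking else 0.0
    n_mask = max(1, min(int((noise_p + remask_p) * seq_len), seq_len - step))

    # low score == "re-mask me", hence largest=False
    is_masked = torch.zeros(seq_len, dtype=torch.bool)
    is_masked[scores.topk(n_mask, largest=False).indices] = True

    # masked-only update: freshly predicted positions take their new confidence,
    # kept positions are pinned to the best score so they are never reselected
    new_scores = torch.where(is_masked, confidences, torch.tensor(1.0))
    return is_masked, new_scores

# toy usage: 4 positions, halfway through an 8-step schedule
is_masked, scores = demask_step(
    scores=torch.tensor([0.9, 0.2, 0.6, 0.1]),
    confidences=torch.tensor([0.95, 0.4, 0.8, 0.3]),
    step=4, max_steps=8,
)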

View File

@@ -255,9 +255,15 @@ class AR_NAR_V2(Base_V2):
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))
largest_score = 1.0
smallest_score = 0.0 # -float("inf")
score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
score_flatten = sampling_kwargs.pop("sampling_scores_flatten", False)
remasking = sampling_kwargs.get("sampling_scores_remask", False)
# to specify the initial mask used
vc_list = sampling_kwargs.pop("vc_list", None)
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
@@ -291,7 +297,7 @@ class AR_NAR_V2(Base_V2):
remask_p = 1.0 / (max_steps * 2) if remasking else 0
mask_p = noise_p + remask_p
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=0, largest=False ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
# mask off inputs
@@ -353,8 +359,15 @@ class AR_NAR_V2(Base_V2):
# update resps, filling in the masked tokens with the new tokens
resps_list = [ torch.where( masked, ids.t(), resps ).to(torch.int16) for masked, ids, resps in zip( is_masked, sampled.ids, resps_list ) ]
# update scores, filling in the
scores = [ 1.0 - torch.where( masked, scores.t(), 1 ) for masked, scores in zip( is_masked, sampled.scores ) ]
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
if score_masked_only:
scores = [ torch.where( masked, scores.t(), largest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
else:
scores = [ scores.t() for scores in sampled.scores ]
# flatten scores across codebook levels instead, so a whole timestep is re-masked or kept together
if score_flatten:
scores = [ score.mean(dim=0).repeat( score.shape[0], 1 ) for score in scores ]
return resps_list
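The new `sampling_scores_flatten` path above collapses the codebook-level dimension of the scores so a whole timestep is re-masked or kept across all levels at once; a minimal sketch of just that collapse (shapes assumed, not the verbatim code):

# illustrative: flatten per-level scores so every codebook level at a timestep is treated together
import torch

levels, seq_len = 8, 12
scores = torch.rand(levels, seq_len)     # per-level, per-position confidence

flattened = scores.mean(dim=0)           # (seq_len,) average across codebook levels
scores = flattened.repeat(levels, 1)     # back to (levels, seq_len), every level now shares the same scores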

View File

@@ -222,15 +222,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
#parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
#parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
parser.add_argument("--max-steps", type=int, default=kwargs["max-steps"])
parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
#parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
@@ -238,6 +238,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
"""
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
@@ -249,10 +250,16 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
"""
parser.add_argument("--refine-on-stop", action="store_true")
parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
parser.add_argument("--cfg-rescale", type=float, default=kwargs['cfg-rescale'])
parser.add_argument("--sampling-scores-masked-only", action="store_true")
parser.add_argument("--sampling-scores-flatten", action="store_true")
parser.add_argument("--sampling-scores-remask", action="store_true")
args, unknown = parser.parse_known_args()
if is_windows:
@@ -279,6 +286,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if kwargs.pop("play", False):
args.play = True
if kwargs.pop("sampling-scores-masked-only", False):
args.sampling_scores_masked_only = True
if kwargs.pop("sampling-scores-flatten", False):
args.sampling_scores_flatten = True
if kwargs.pop("sampling-scores-remask", False):
args.sampling_scores_remask = True
if args.split_text_by == "lines":
args.split_text_by = "\n"
@@ -298,28 +314,32 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
phonemize=not args.no_phonemize,
voice_convert=args.voice_convert,
max_steps=args.max_steps,
max_levels=args.max_levels,
#max_levels=args.max_levels,
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
refine_on_stop=args.refine_on_stop,
#beam_width=args.beam_width,
#mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
#dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
#entropix_sampling=args.entropix_sampling,
#layer_skip=args.layer_skip,
#layer_skip_exit_layer=args.layer_skip_exit_layer,
#layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
#layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
#refine_on_stop=args.refine_on_stop,
denoise_start=args.denoise_start,
prefix_silence=args.prefix_silence,
input_prompt_prefix=args.input_prompt_prefix,
#prefix_silence=args.prefix_silence,
#input_prompt_prefix=args.input_prompt_prefix,
input_prompt_length=args.input_prompt_length,
cfg_strength=args.cfg_strength,
cfg_rescale=args.cfg_rescale,
sampling_scores_masked_only=args.sampling_scores_masked_only,
sampling_scores_flatten=args.sampling_scores_flatten,
sampling_scores_remask=args.sampling_scores_remask,
)
with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
@@ -497,7 +517,12 @@ with ui:
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():
layout["inference_tts"]["inputs"]["sampling-scores-masked-only"] = gr.Checkbox(label="Sampled Scores: Masked Only", info="(NAR-len only) Update scores for newly generated tokens only")
layout["inference_tts"]["inputs"]["sampling-scores-flattened"] = gr.Checkbox(label="Sampled Scores: Flattened", info="(NAR-len only) Flattens the scores for all codebook levels")
layout["inference_tts"]["inputs"]["sampling-scores-remask"] = gr.Checkbox(label="Sampled Scores: Remask", info="(NAR-len only) Remasks P%% of existing tokens randomly after each step.")
# These settings are pretty much not supported anyways
"""
with gr.Tab("Experimental Settings", visible=cfg.experimental):
with gr.Row():
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
@@ -520,6 +545,7 @@ with ui:
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
"""
layout["inference_tts"]["buttons"]["inference"].click(
fn=do_inference_tts,
@@ -565,6 +591,7 @@ with ui:
layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
"""
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_stt"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
@@ -573,6 +600,7 @@ with ui:
layout["inference_stt"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference_stt"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference_stt"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
"""
layout["inference_stt"]["buttons"]["inference"].click(
fn=do_inference_stt,