new implementation tweaks and fixes to make it actually better (there were a lot of bad/wrong things being done that harmed the output quality; will evaluate the model further)
This commit is contained in:
parent 98d1d8cb1e
commit d9e18037cc
@@ -152,6 +152,14 @@ These settings should be avoided:
* however, it seems there's a regression that caused this to stop working consistently
* disabling this falls back to explicitly training a `len` task (like the old implementation)

## Samplers

To-do: Remember what I was going to jot down here.

Sampling code is effectively the same, with the twist of instead outputting the logits for all codebooks at `dim=0`.

The NAR-demasking step will account for this automatically, and has dials and knobs to adjust whether tokens are masked off independently of other codebook levels, or across all codebook levels at a given timestep.

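To make the demasking step above concrete, here is a minimal sketch of how per-codebook scores could drive the mask selection, assuming scores are laid out as `[codebooks, seq_len]` and that lower scores mean lower confidence; the helper name, its arguments, and the `flatten_levels` toggle are illustrative stand-ins, not the actual API.

```python
import torch

def pick_tokens_to_remask(scores: torch.Tensor, mask_ratio: float, flatten_levels: bool = False) -> torch.Tensor:
	"""Return a boolean mask of shape [codebooks, seq_len] marking the tokens to re-mask this step."""
	n_codebooks, seq_len = scores.shape
	k = max(1, int(mask_ratio * seq_len))

	if flatten_levels:
		# average scores across codebook levels so every level at a timestep is kept or re-masked together
		scores = scores.mean(dim=0, keepdim=True).repeat(n_codebooks, 1)

	# pick the *worst* scoring tokens per codebook level (largest=False)
	worst = scores.topk(k, dim=-1, largest=False).indices
	mask = torch.zeros_like(scores, dtype=torch.bool)
	mask.scatter_(-1, worst, True)
	return mask

# example: 8 codebooks, 75 frames, re-mask half of the tokens this step
mask = pick_tokens_to_remask(torch.rand(8, 75), mask_ratio=0.5)
```

With `flatten_levels=True`, every codebook level at a given timestep is kept or re-masked together, mirroring the "all codebook levels at a given timestep" dial described above.
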
## Benefits and Caveats

To be evaluated thoroughly.
@@ -296,9 +296,14 @@ class AR_NAR(Base):
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))

largest_score = 1.0
smallest_score = 0.0 # -float("inf")

score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
remasking = sampling_kwargs.get("sampling_scores_remask", False)

# to specify the initial mask used
vc_list = sampling_kwargs.pop("vc_list", None)
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
@@ -326,7 +331,7 @@ class AR_NAR(Base):
# gen masking ratio
noise_p = math.cos( start_noise * math.pi * 0.5 )
# generate scoring mask (because the above mask will get masked off per the scores, so we do not need to mask beforehand)
scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
scores = [ torch.tensor( [ smallest_score if random.random() < noise_p else largest_score for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
else:
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
@@ -350,7 +355,7 @@ class AR_NAR(Base):
remask_p = 1.0 / (max_steps * 2) if remasking else 0
mask_p = noise_p + remask_p
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=-1, largest=False ).indices for score, seq_len in zip(scores, len_list) ]

# normal masking
if vc_list is None or timestep >= vc_threshold:
@@ -450,15 +455,11 @@ class AR_NAR(Base):
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# get probability scores
scores = [
# conjugate to have worse scoring tokens picked for topk
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
]
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
if score_masked_only:
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
else:
scores = [ scores for scores in sampled.scores ]

return resps_list
@@ -255,9 +255,15 @@ class AR_NAR_V2(Base_V2):
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))

largest_score = 1.0
smallest_score = 0.0 # -float("inf")

score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
score_flatten = sampling_kwargs.pop("sampling_scores_flatten", False)
remasking = sampling_kwargs.get("sampling_scores_remask", False)

# to specify the initial mask used
vc_list = sampling_kwargs.pop("vc_list", None)
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
@@ -291,7 +297,7 @@ class AR_NAR_V2(Base_V2):
remask_p = 1.0 / (max_steps * 2) if remasking else 0
mask_p = noise_p + remask_p
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=0, largest=False ).indices for score, seq_len in zip(scores, len_list) ]

# normal masking
# mask off inputs
@@ -353,8 +359,15 @@ class AR_NAR_V2(Base_V2):

# update resps, filling in the masked tokens with the new tokens
resps_list = [ torch.where( masked, ids.t(), resps ).to(torch.int16) for masked, ids, resps in zip( is_masked, sampled.ids, resps_list ) ]
# update scores, filling in the
scores = [ 1.0 - torch.where( masked, scores.t(), 1 ) for masked, scores in zip( is_masked, sampled.scores ) ]
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
if score_masked_only:
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
else:
scores = [ scores.t() for scores in sampled.scores ]

# drop all levels at the timestep instead
if score_flatten:
scores = [ score.mean(dim=0).repeat( score.shape[0], 1 ) for score in scores ]

return resps_list
@@ -222,15 +222,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
#parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
#parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
parser.add_argument("--max-steps", type=int, default=kwargs["max-steps"])
parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
#parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
@@ -238,6 +238,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
"""
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
@@ -249,10 +250,16 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
"""
parser.add_argument("--refine-on-stop", action="store_true")
parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
parser.add_argument("--cfg-rescale", type=float, default=kwargs['cfg-rescale'])

parser.add_argument("--sampling-scores-masked-only", action="store_true")
parser.add_argument("--sampling-scores-flatten", action="store_true")
parser.add_argument("--sampling-scores-remask", action="store_true")

args, unknown = parser.parse_known_args()

if is_windows:
@@ -280,6 +287,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if kwargs.pop("play", False):
args.play = True

if kwargs.pop("sampling-scores-masked-only", False):
args.sampling_scores_masked_only = True

if kwargs.pop("sampling-scores-flatten", False):
args.sampling_scores_flatten = True

if kwargs.pop("sampling-scores-remask", False):
args.sampling_scores_remask = True

if args.split_text_by == "lines":
args.split_text_by = "\n"
elif args.split_text_by == "none":
@@ -298,28 +314,32 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
phonemize=not args.no_phonemize,
voice_convert=args.voice_convert,
max_steps=args.max_steps,
max_levels=args.max_levels,
#max_levels=args.max_levels,
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
refine_on_stop=args.refine_on_stop,
#beam_width=args.beam_width,
#mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
#dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
#entropix_sampling=args.entropix_sampling,
#layer_skip=args.layer_skip,
#layer_skip_exit_layer=args.layer_skip_exit_layer,
#layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
#layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
#refine_on_stop=args.refine_on_stop,
denoise_start=args.denoise_start,
prefix_silence=args.prefix_silence,
input_prompt_prefix=args.input_prompt_prefix,
#prefix_silence=args.prefix_silence,
#input_prompt_prefix=args.input_prompt_prefix,
input_prompt_length=args.input_prompt_length,
cfg_strength=args.cfg_strength,
cfg_rescale=args.cfg_rescale,

sampling_scores_masked_only=args.sampling_scores_masked_only,
sampling_scores_flatten=args.sampling_scores_flatten,
sampling_scores_remask=args.sampling_scores_remask,
)

with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
@@ -497,7 +517,12 @@ with ui:
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():
layout["inference_tts"]["inputs"]["sampling-scores-masked-only"] = gr.Checkbox(label="Sampled Scores: Masked Only", info="(NAR-len only) Update scores for newly generated tokens only")
layout["inference_tts"]["inputs"]["sampling-scores-flattened"] = gr.Checkbox(label="Sampled Scores: Flattened", info="(NAR-len only) Flattens the scores for all codebook levels")
layout["inference_tts"]["inputs"]["sampling-scores-remask"] = gr.Checkbox(label="Sampled Scores: Remask", info="(NAR-len only) Remasks P%% of existing tokens randomly after each step.")
# These settings are pretty much not supported anyways
"""
with gr.Tab("Experimental Settings", visible=cfg.experimental):
with gr.Row():
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
@@ -520,6 +545,7 @@ with ui:
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
"""

layout["inference_tts"]["buttons"]["inference"].click(
fn=do_inference_tts,
@@ -565,6 +591,7 @@ with ui:
layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
"""
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_stt"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
@@ -573,6 +600,7 @@ with ui:
layout["inference_stt"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference_stt"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference_stt"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
"""

layout["inference_stt"]["buttons"]["inference"].click(
fn=do_inference_stt,