diff --git a/docs/models_v2.md b/docs/models_v2.md
index 62978bf..1fc10ef 100644
--- a/docs/models_v2.md
+++ b/docs/models_v2.md
@@ -152,6 +152,14 @@ These settings should be avoided:
 * however, it seems there's a regression that caused this to stop working consistently
 * disabling this falls back to explicitly training a `len` task (like the old implementation)
 
+## Samplers
+
+To-do: Remember what I was going to jot down here
+
+Sampling code is effectively the same, with the twist that the logits for all codebooks are output along `dim=0`.
+
+The NAR-demasking step accounts for this automatically, and has dials and knobs to adjust whether tokens are masked off independently per codebook level, or across all codebook levels at a given timestep.
+
 ## Benefits and Caveats
 
 To be evaluated thoroughly.
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index bb30e27..813fdda 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -296,9 +296,14 @@ class AR_NAR(Base):
 		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
-		remasking = sampling_kwargs.get("remasking", True)
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
 
+		largest_score = 1.0
+		smallest_score = 0.0 # -float("inf")
+
+		score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
+		remasking = sampling_kwargs.get("sampling_scores_remask", False)
+
 		# to specify the initial mask used
 		vc_list = sampling_kwargs.pop("vc_list", None)
 		vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
@@ -326,7 +331,7 @@ class AR_NAR(Base):
 			# gen masking ratio
 			noise_p = math.cos( start_noise * math.pi * 0.5 )
 			# generate scoring mask (because the above mask will get masked off per the scores, so we do not need to mask beforehand)
-			scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
+			scores = [ torch.tensor( [ smallest_score if random.random() < noise_p else largest_score for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
 		else:
 			# fill with masked tokens (even though they get masked anyways)
 			resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
@@ -350,7 +355,7 @@ class AR_NAR(Base):
 			remask_p = 1.0 / (max_steps * 2) if remasking else 0
 			mask_p = noise_p + remask_p
 			# pick the worst scoring tokens to mask off
-			masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
+			masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=-1, largest=False ).indices for score, seq_len in zip(scores, len_list) ]
 
 			# normal masking
 			if vc_list is None or timestep >= vc_threshold:
@@ -450,15 +455,11 @@ class AR_NAR(Base):
 				sampled_ids = filtered_sampled.ids
 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
-			# get probability scores
-			scores = [
-				# conjugate to have worse scoring tokens picked for topk
-				1.0 - 
-				# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
-				torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
-				# use unmodified logit scores for this, as it offers better stability
-				for scores, masked in zip( unfiltered_sampled.scores, is_masked )
-			]
+			# update scores, only updating tokens that were masked off, and force keeping unmasked tokens (use the unfiltered scores, as they offer better stability)
+			if score_masked_only:
+				scores = [ torch.where( masked, scores, largest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ]
+			else:
+				scores = [ scores for scores in unfiltered_sampled.scores ]
 
 		return resps_list
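Note: the demasking loop above is easier to audit in isolation. Below is a minimal, self-contained sketch of the schedule it implements (the cosine noise ratio, `topk(..., largest=False)` selection of the least-confident tokens, and the masked-only score update). The `demask_sketch` name, the random stand-ins for the model's sampled ids and confidences, and the shapes/step counts are illustrative assumptions, not the repo's actual API:

```python
# Illustrative sketch of the NAR demasking schedule; not the repo's API.
import math
import torch

def demask_sketch( seq_len=32, max_steps=8, mask_token=1024, remasking=False ):
	torch.manual_seed( 0 )
	largest_score, smallest_score = 1.0, 0.0

	resps = torch.full( (seq_len,), mask_token, dtype=torch.int16 )
	# everything starts masked, so every position starts at the smallest score
	scores = torch.full( (seq_len,), smallest_score )

	for step in range( max_steps ):
		timestep = step / max_steps
		# cosine schedule: the mask ratio decays from ~1.0 towards 0.0
		noise_p = math.cos( timestep * math.pi * 0.5 )
		# remasking nudges the ratio up so finalized tokens can be revisited
		remask_p = 1.0 / (max_steps * 2) if remasking else 0
		mask_p = noise_p + remask_p
		k = max( 1, min( int( mask_p * seq_len ), seq_len - step ) )
		# pick the k worst (lowest) scoring positions to mask off
		masked_indices = scores.topk( k, dim=-1, largest=False ).indices
		resps[masked_indices] = mask_token
		is_masked = resps == mask_token
		# stand-ins for the model's sampled ids and confidence scores
		new_ids = torch.randint( 0, mask_token, (seq_len,), dtype=torch.int16 )
		new_scores = torch.rand( seq_len )
		resps = torch.where( is_masked, new_ids, resps )
		# masked-only update: finalized tokens get the largest score,
		# so topk(largest=False) will not pick them again
		scores = torch.where( is_masked, new_scores, largest_score )
	return resps
```

Under this framing, `remasking` corresponds to the new `sampling_scores_remask` flag, and the final `torch.where` to `sampling_scores_masked_only`.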
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index 99baf21..bf4149b 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -255,9 +255,15 @@ class AR_NAR_V2(Base_V2):
 		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
-		remasking = sampling_kwargs.get("remasking", True)
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
 
+		largest_score = 1.0
+		smallest_score = 0.0 # -float("inf")
+
+		score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
+		score_flatten = sampling_kwargs.pop("sampling_scores_flatten", False)
+		remasking = sampling_kwargs.get("sampling_scores_remask", False)
+
 		# to specify the initial mask used
 		vc_list = sampling_kwargs.pop("vc_list", None)
 		vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
@@ -291,7 +297,7 @@ class AR_NAR_V2(Base_V2):
 			remask_p = 1.0 / (max_steps * 2) if remasking else 0
 			mask_p = noise_p + remask_p
 			# pick the worst scoring tokens to mask off
-			masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
+			masked_indices = [ score.topk( clamp( int( mask_p * seq_len ), 1, seq_len - step), dim=0, largest=False ).indices for score, seq_len in zip(scores, len_list) ]
 
 			# normal masking
 			# mask off inputs
@@ -353,8 +359,15 @@ class AR_NAR_V2(Base_V2):
 			# update resps, filling in the masked tokens with the new tokens
 			resps_list = [ torch.where( masked, ids.t(), resps ).to(torch.int16) for masked, ids, resps in zip( is_masked, sampled.ids, resps_list ) ]
 
-			# update scores, filling in the 
-			scores = [ 1.0 - torch.where( masked, scores.t(), 1 ) for masked, scores in zip( is_masked, sampled.scores ) ]
+			# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
+			if score_masked_only:
+				scores = [ torch.where( masked, scores.t(), largest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
+			else:
+				scores = [ scores.t() for scores in sampled.scores ]
+
+			# flatten scores across codebook levels so all levels at a given timestep are masked or kept together
+			if score_flatten:
+				scores = [ score.mean(dim=0).repeat( score.shape[0], 1 ) for score in scores ]
 
 		return resps_list
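Note: the V2 path runs the same update per codebook level: `sampled.scores` comes back transposed (hence the `.t()` calls) with one row per codebook, and `sampling_scores_flatten` averages those per-level confidences so an entire timestep is masked or kept as a unit. A small, runnable sketch of just that flattening step, with assumed shapes:

```python
# Illustrative sketch of the score_flatten branch; shapes are assumptions.
import torch

levels, seq_len = 8, 6
torch.manual_seed( 0 )
# per-level confidence scores: one row per codebook level
scores = torch.rand( levels, seq_len )

# average across levels, then broadcast the mean back to every level,
# so topk masking selects the same timesteps at all codebook levels
flattened = scores.mean( dim=0 ).repeat( scores.shape[0], 1 )

assert flattened.shape == (levels, seq_len)
# every column is now constant across levels
assert torch.allclose( flattened[0], flattened[-1] )
```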
diff --git a/vall_e/webui.py b/vall_e/webui.py
index f2694f6..83201ff 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -222,15 +222,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
 	parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
-	parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	#parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
 	parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
-	parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
+	#parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
 	parser.add_argument("--max-steps", type=int, default=kwargs["max-steps"])
 	parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
 	parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
 	parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
 	parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
-	parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
+	#parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
 	parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
 	parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
 	parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
@@ -238,6 +238,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
 	parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
 	parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
+	"""
 	parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
 	parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
 	parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
@@ -249,10 +250,16 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
 	parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
 	parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
+	"""
 	parser.add_argument("--refine-on-stop", action="store_true")
 	parser.add_argument("--denoise-start", type=float, default=0.0)
 	parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength'])
 	parser.add_argument("--cfg-rescale", type=float, default=kwargs['cfg-rescale'])
+
+	parser.add_argument("--sampling-scores-masked-only", action="store_true")
+	parser.add_argument("--sampling-scores-flatten", action="store_true")
+	parser.add_argument("--sampling-scores-remask", action="store_true")
+
 	args, unknown = parser.parse_known_args()
 
 	if is_windows:
@@ -279,6 +286,15 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 
 	if kwargs.pop("play", False):
 		args.play = True
+
+	if kwargs.pop("sampling-scores-masked-only", False):
+		args.sampling_scores_masked_only = True
+
+	if kwargs.pop("sampling-scores-flatten", False):
+		args.sampling_scores_flatten = True
+
+	if kwargs.pop("sampling-scores-remask", False):
+		args.sampling_scores_remask = True
 
 	if args.split_text_by == "lines":
 		args.split_text_by = "\n"
@@ -298,28 +314,32 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		phonemize=not args.no_phonemize,
 		voice_convert=args.voice_convert,
 		max_steps=args.max_steps,
-		max_levels=args.max_levels,
+		#max_levels=args.max_levels,
 		max_duration=args.max_duration,
 		ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
 		min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
 		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
 		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
 		length_penalty=args.length_penalty,
 
-		beam_width=args.beam_width,
-		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
-		entropix_sampling=args.entropix_sampling,
-		layer_skip=args.layer_skip,
-		layer_skip_exit_layer=args.layer_skip_exit_layer,
-		layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
-		layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
-		refine_on_stop=args.refine_on_stop,
+		#beam_width=args.beam_width,
+		#mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+		#dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+		#entropix_sampling=args.entropix_sampling,
+		#layer_skip=args.layer_skip,
+		#layer_skip_exit_layer=args.layer_skip_exit_layer,
+		#layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+		#layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+		#refine_on_stop=args.refine_on_stop,
 		denoise_start=args.denoise_start,
-		prefix_silence=args.prefix_silence,
-		input_prompt_prefix=args.input_prompt_prefix,
+		#prefix_silence=args.prefix_silence,
+		#input_prompt_prefix=args.input_prompt_prefix,
 		input_prompt_length=args.input_prompt_length,
 		cfg_strength=args.cfg_strength, cfg_rescale=args.cfg_rescale,
+
+		sampling_scores_masked_only=args.sampling_scores_masked_only,
+		sampling_scores_flatten=args.sampling_scores_flatten,
+		sampling_scores_remask=args.sampling_scores_remask,
 	)
 
 	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
@@ -497,7 +517,12 @@ with ui:
 				layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 				layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 				layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+			with gr.Row():
+				layout["inference_tts"]["inputs"]["sampling-scores-masked-only"] = gr.Checkbox(label="Sampled Scores: Masked Only", info="(NAR-len only) Update scores for newly generated tokens only")
+				layout["inference_tts"]["inputs"]["sampling-scores-flatten"] = gr.Checkbox(label="Sampled Scores: Flattened", info="(NAR-len only) Flattens the scores for all codebook levels")
+				layout["inference_tts"]["inputs"]["sampling-scores-remask"] = gr.Checkbox(label="Sampled Scores: Remask", info="(NAR-len only) Remasks P% of existing tokens randomly after each step.")
 			# These settings are pretty much not supported anyways
+			"""
 			with gr.Tab("Experimental Settings", visible=cfg.experimental):
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
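Note: in the webui plumbing, the Gradio component keys double as the `kwargs` keys that `do_inference_tts` pops, so the checkbox keys must exactly match the hyphenated names used by `kwargs.pop(...)` above; the flatten checkbox is therefore keyed `sampling-scores-flatten` to match `kwargs.pop("sampling-scores-flatten", False)`. Condensed, the round trip for one flag looks like this (a toy trace of the flow, not actual webui code):

```python
# Toy trace of one flag's round trip through the webui plumbing.
kwargs = { "sampling-scores-flatten": True }               # Gradio checkbox key
flag = kwargs.pop( "sampling-scores-flatten", False )      # popped into args.sampling_scores_flatten
sampling_kwargs = { "sampling_scores_flatten": flag }      # forwarded into tts.inference(...)
score_flatten = sampling_kwargs.pop( "sampling_scores_flatten", False )  # consumed by the model
assert score_flatten
```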
@@ -520,6 +545,7 @@ with ui:
 					layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
 					layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
 					layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
+			"""
 
 	layout["inference_tts"]["buttons"]["inference"].click(
 		fn=do_inference_tts,
@@ -565,6 +591,7 @@ with ui:
 				layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 				layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 				layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+			"""
 			with gr.Row():
 				layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
 				layout["inference_stt"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
@@ -573,6 +600,7 @@ with ui:
 				layout["inference_stt"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
 				layout["inference_stt"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
 				layout["inference_stt"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
+			"""
 
 	layout["inference_stt"]["buttons"]["inference"].click(
 		fn=do_inference_stt,
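Note: for completeness, the new knobs can also be passed programmatically; the keyword names mirror the argparse flags above. A hypothetical call, assuming a constructed TTS wrapper like the one the webui loads (not runnable without a model):

```python
# Hypothetical usage; `tts` stands in for the webui's loaded TTS wrapper.
audio = tts.inference(
	text="Hello world.",
	references=["./reference.wav"],
	max_steps=25,
	cfg_strength=1.0, cfg_rescale=0.75,
	# new NAR demasking knobs introduced by this patch:
	sampling_scores_masked_only=True, # only update scores for freshly generated tokens
	sampling_scores_flatten=True,     # (V2 only) mask/keep all codebook levels per timestep
	sampling_scores_remask=False,     # randomly remask a small portion of tokens each step
)
```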