From d229725c76dba4adc9a12ddddeb6a06483ac4faf Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 3 Nov 2024 18:31:28 -0600
Subject: [PATCH] more adjustments (adjustments of early-exit
 entropy/varentropy thresholds, default rep pen being 1.5, experimental
 refine-on-stop, etc.)

---
 README.md               |  9 +++++++--
 vall_e/__main__.py      |  8 +++++++-
 vall_e/config.py        |  2 +-
 vall_e/demo.py          |  6 ++++++
 vall_e/inference.py     | 14 ++++++++++++--
 vall_e/models/ar_nar.py | 33 ++++++++++++++++++++++++++++++---
 vall_e/models/base.py   | 27 ++++++++++++++++++---------
 vall_e/webui.py         | 16 ++++++++++++++--
 8 files changed, 95 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 2fae71d..1f46535 100755
--- a/README.md
+++ b/README.md
@@ -240,8 +240,13 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor.
 * `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor.
 * `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within.
-* `--layer-skip` (AR only) enables early-exit layer skipping if the model is confident enough (for compatible models)
-* `--layer-skip-exit-layer`: (AR only) maximum layer to use (for compatbiel models)
+* `--layer-skip`: enables early-exit layer skipping if the model is confident enough (for compatible models)
+* `--layer-skip-exit-layer`: maximum layer to use
+* `--layer-skip-entropy-threshold`: the maximum the logits' entropy (confidence) needs to be before exiting early
+* `--layer-skip-varentropy-threshold`: the maximum the logits' varentropy (confidence spread) needs to be before exiting early
+* `--refine-on-stop`: (AR only) uses the last step's logits for the entire final output sequence, rather than the step-by-step iterative sequence.
+  + This needs experimenting with to see if there's any downside.
+  + to-do: compare the probability scores with the original output sequence, and pick the best one.
 
 ### Speech-to-Text
 
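The two new thresholds gate early exit on the same statistics that entropix sampling reads from the logits: entropy (how peaked the prediction is) and varentropy (how consistent that confidence is). As a reference for what is being thresholded, here is a minimal illustrative sketch of the math; it is not the repo's actual `calculate_entropix_metrics`, which may differ in details such as log base:

```python
import torch
import torch.nn.functional as F

def entropy_varentropy(logits: torch.Tensor) -> tuple[float, float]:
    # distribution over the vocabulary for this step
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # entropy: low when the model is confidently peaked on few tokens
    entropy = -(probs * log_probs).sum(dim=-1)
    # varentropy: variance of the surprisal; low when that confidence is consistent
    varentropy = (probs * (log_probs + entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
    return entropy.mean().item(), varentropy.mean().item()

def should_exit_early(logits, entropy_threshold=0.1, varentropy_threshold=0.1) -> bool:
    # both statistics must fall at or below their thresholds,
    # mirroring the `and` in layer_skip_lambda further down this patch
    entropy, varentropy = entropy_varentropy(logits)
    return entropy <= entropy_threshold and varentropy <= varentropy_threshold
```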
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index c6745c8..0a3a81f 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -33,7 +33,7 @@ def main():
     parser.add_argument("--top-p", type=float, default=1.0)
     parser.add_argument("--top-k", type=int, default=0)
     parser.add_argument("--min-p", type=float, default=0.0)
-    parser.add_argument("--repetition-penalty", type=float, default=1.125)
+    parser.add_argument("--repetition-penalty", type=float, default=1.5)
     parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
     parser.add_argument("--length-penalty", type=float, default=0.0)
     parser.add_argument("--beam-width", type=int, default=0)
@@ -49,6 +49,9 @@ def main():
 
     parser.add_argument("--layer-skip", action="store_true")
     parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
+    parser.add_argument("--layer-skip-entropy-threshold", type=float, default=0.1)
+    parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=0.1)
+    parser.add_argument("--refine-on-stop", action="store_true")
 
     parser.add_argument("--seed", type=int, default=None)
 
@@ -86,6 +89,9 @@ def main():
         entropix_sampling=args.entropix_sampling,
         layer_skip=args.layer_skip,
         layer_skip_exit_layer=args.layer_skip_exit_layer,
+        layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+        layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+        refine_on_stop=args.refine_on_stop,
         seed=args.seed,
     )
 
diff --git a/vall_e/config.py b/vall_e/config.py
index 8dbc698..401c2a5 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -115,7 +115,7 @@ class BaseConfig:
             raise Exception(f'Model path does not exist: {model_path}')
 
         # load state dict and copy its stored model config
-        model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } ] if model_path and model_path.exists() else []
+        model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path, "attention": "auto" } ] if model_path and model_path.exists() else []
         lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
 
         state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 61b053a..ef74687 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -82,6 +82,12 @@ def main():
     parser.add_argument("--dry-allowed-length", type=int, default=2)
 
     parser.add_argument("--entropix-sampling", action="store_true")
+
+    parser.add_argument("--layer-skip", action="store_true")
+    parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
+    parser.add_argument("--layer-skip-entropy-threshold", type=float, default=0.1)
+    parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=0.1)
+    parser.add_argument("--refine-on-stop", action="store_true")
 
     parser.add_argument("--seed", type=int, default=None)
 
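Also worth flagging in the hunks above: the default repetition penalty rises from 1.125 to a much stronger 1.5. For reference, a minimal sketch of the usual divide-positive/multiply-negative scheme such a scalar drives; this is illustrative only, and the repo's implementation (which also supports a length decay) may differ:

```python
import torch

def apply_repetition_penalty(logits: torch.Tensor, previous: torch.Tensor, penalty: float = 1.5) -> torch.Tensor:
    # previous: 1-D tensor of already-emitted token ids; logits: [vocab]
    logits = logits.clone()
    scores = logits[previous]
    # divide positive scores, multiply negative ones,
    # so repeated tokens always become less likely
    logits[previous] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits
```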
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 5363942..735f2e6 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -223,6 +223,10 @@ class TTS():
         #
         layer_skip=False,
         layer_skip_exit_layer=-1,
+        layer_skip_entropy_threshold=-1,
+        layer_skip_varentropy_threshold=-1,
+        #
+        refine_on_stop=False,
         #
         seed = None,
 
@@ -275,6 +279,7 @@ class TTS():
             sampling_entropix=entropix_sampling,
             sampling_layer_skip=layer_skip,
             sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+            sampling_refine_on_stop=refine_on_stop,
 
             disable_tqdm=not tqdm,
             use_lora=use_lora,
@@ -326,6 +331,9 @@ class TTS():
             sampling_entropix=entropix_sampling,
             sampling_layer_skip=layer_skip,
             sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+            sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
+            sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
+            sampling_refine_on_stop=refine_on_stop,
 
             disable_tqdm=not tqdm,
             use_lora=use_lora,
@@ -338,8 +346,10 @@ class TTS():
             sampling_min_temperature=min_nar_temp,
             sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
             sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-            #sampling_layer_skip=layer_skip,
-            #sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+            sampling_layer_skip=layer_skip,
+            sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+            sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
+            sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
 
             disable_tqdm=not tqdm,
             use_lora=use_lora,
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 575320f..98fc2e2 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -69,6 +69,10 @@ class AR_NAR(Base):
 
         sampling_layer_skip: bool = False,
         sampling_layer_skip_exit_layer: int = -1,
+        sampling_layer_skip_entropy_threshold: float = -1,
+        sampling_layer_skip_varentropy_threshold: float = -1,
+
+        sampling_refine_on_stop: bool = False,
 
         disable_tqdm=False,
         use_lora=None,
@@ -208,7 +212,12 @@ class AR_NAR(Base):
         sampling_layer_skip_variables = {} if sampling_layer_skip else None
 
         if sampling_layer_skip:
-            sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
+            if sampling_layer_skip_entropy_threshold >= 0:
+                sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
+            if sampling_layer_skip_varentropy_threshold >= 0:
+                sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
+            if sampling_layer_skip_exit_layer >= 0:
+                sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
         for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
             level = prev_list[0].shape[-1]
@@ -292,7 +301,12 @@ class AR_NAR(Base):
         sampling_layer_skip_variables = {} if sampling_layer_skip else None
 
         if sampling_layer_skip:
-            sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
+            if sampling_layer_skip_entropy_threshold >= 0:
+                sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
+            if sampling_layer_skip_varentropy_threshold >= 0:
+                sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
+            if sampling_layer_skip_exit_layer >= 0:
+                sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
         for i, sequence in enumerate( sequence_list ):
             # add to text for STT
@@ -377,7 +391,7 @@ class AR_NAR(Base):
                 if sampled.entropy:
                     metrics.append( sampled.entropy )
                 elif sampled.scores:
-                    metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
+                    metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
 
             if mirostat is not None:
                 mirostat = sampled.scores
@@ -409,6 +423,8 @@ class AR_NAR(Base):
             if stopped.all().item():
                 break
 
+        # to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
+
         if metrics:
             from ..plot import plot_sample_metrics
             filename = "metrics"
@@ -430,6 +446,17 @@ class AR_NAR(Base):
         # remove
         sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
 
+        if sampling_refine_on_stop:
+            # get how much we need to slice from the end
+            slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ]
+            # -1 for the stop token
+            logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
+            # greedy sample from the sequence
+            refined_list = [ logit.argmax(dim=-1) for logit in logits ]
+            # to-do: compare scores
+            # set the "refined" list as the output
+            sequence_list = refined_list
+
         return sequence_list
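The `sampling_refine_on_stop` block above greedily re-decodes the entire utterance from the final step's logits and leaves `# to-do: compare scores` open. One plausible shape for that to-do, a sketch rather than code from this commit: sum each candidate's log-probabilities under those same logits and keep the higher-scoring sequence.

```python
import torch
import torch.nn.functional as F

def pick_better_sequence(logits: torch.Tensor, original: torch.Tensor, refined: torch.Tensor) -> torch.Tensor:
    # logits: [T, vocab] from the final step; original/refined: [T] token ids
    log_probs = F.log_softmax(logits, dim=-1)

    def score(seq: torch.Tensor) -> torch.Tensor:
        # total log-probability the final-step logits assign to this sequence
        return log_probs.gather(-1, seq.unsqueeze(-1)).sum()

    return refined if score(refined) >= score(original) else original
```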
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index d8de287..8a2fe19 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -39,7 +39,7 @@ from ..emb.qnt import encode_as_embedding
 from ..data import get_task_symmap
 
 # these seem more elegant than a dict
-Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
+Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
 Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
@@ -942,7 +942,7 @@ class Base(nn.Module):
             # but skip the last state, as it already is normalized
             hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
 
-        return Logits(x, state, aux_loss, attentions, hidden_states)
+        return Logits(x, state, aux_loss, attentions, hidden_states, None)
 
     # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
     def inputs(
@@ -1444,11 +1444,13 @@ class Base(nn.Module):
     ):
         # return early if it's "good" enough"
         # lambda because we need to capture the classifier_quant_levels and mask
+        exited_layer = self.n_layers
         def layer_skip_lambda( layer, logits ):
+            nonlocal exited_layer
             kwargs = {
-                "logits_entropy": 0.1,
-                "logits_varentropy": 0.1,
-                "min_layer": self.n_layers // 4,
+                "entropy_threshold": 0.05,
+                "varentropy_threshold": 0.05,
+                "min_layer": self.n_layers // 2,
                 "max_layer": self.n_layers,
             }
@@ -1472,9 +1474,15 @@ class Base(nn.Module):
 
             # calculate metrics
             metrics = calculate_entropix_metrics( logits )
-            # exit early if "good enough""
-            return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
+            early = metrics["logits_entropy"] <= kwargs["entropy_threshold"] and metrics["logits_varentropy"] <= kwargs["varentropy_threshold"]
+
+            if early:
+                exited_layer = layer
+
+            #print( layer, early, metrics )
+
+            return early
 
         x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
@@ -1526,7 +1534,8 @@ class Base(nn.Module):
         logits = output.logits
         hidden_states = output.hidden_states
 
-        # output projection layer with masking
+        # output projection layer
+        # the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
         if self.classifier is not None:
             logits = self.classifier(logits) # * m
 
@@ -1582,7 +1591,7 @@ class Base(nn.Module):
             self.stats = stats
 
         # rewrap, because we're modifying the logits here
-        return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states)
+        return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states, exited_layer)
 
     def sample(
         self,
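For orientation: `layer_skip_lambda` is handed to the underlying transformer so each decoder layer can ask whether the intermediate logits are already confident enough to stop, and `exited_layer` records where that happened (now carried through the `Logits` namedtuple). A rough sketch of that control flow, with hypothetical `layers` / `norm` / `classifier` names standing in for the real model internals:

```python
def forward_with_early_exit(x, layers, norm, classifier, layer_skip_lambda=None):
    # run decoder layers one at a time, offering an exit after each
    for i, layer in enumerate(layers):
        x = layer(x)
        if layer_skip_lambda is not None:
            # project the normalized hidden state to logits so the
            # callback can judge confidence at this depth
            logits = classifier(norm(x))
            if layer_skip_lambda(i, logits):
                return logits, i  # exited early at layer i
    return classifier(norm(x)), len(layers)
```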
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 15816ae..d76f2a6 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -194,6 +194,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--entropix-sampling", action="store_true")
     parser.add_argument("--layer-skip", action="store_true")
     parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
+    parser.add_argument("--layer-skip-entropy-threshold", type=float, default=kwargs["layer-skip-entropy-threshold"] if cfg.experimental else 0.1)
+    parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=kwargs["layer-skip-varentropy-threshold"] if cfg.experimental else 0.1)
+    parser.add_argument("--refine-on-stop", action="store_true")
     args, unknown = parser.parse_known_args()
 
     tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -208,6 +211,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 
     if kwargs.pop("layer-skip", False):
         args.layer_skip = True
+
+    if kwargs.pop("refine-on-stop", False):
+        args.refine_on_stop = True
 
     tts = init_tts()
 
@@ -242,8 +248,11 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         dry_base=args.dry_base,
         dry_allowed_length=args.dry_allowed_length,
         entropix_sampling=args.entropix_sampling,
+        layer_skip=args.layer_skip,
-        layer_skip_exit_layer=args.layer_skip_exit_layer,
+        layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+        layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+        refine_on_stop=args.refine_on_stop,
     )
 
     wav = wav.squeeze(0).cpu().numpy()
 
@@ -385,6 +394,7 @@ with ui:
                     layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
                 with gr.Row():
                     layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
+                    layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on Stop", info="Uses the last step's logits for the AR sequence instead.")
                     layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
             with gr.Tab("Sampler Settings"):
                 with gr.Row():
@@ -393,7 +403,7 @@ with ui:
                     layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
                     layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
                 with gr.Row():
-                    layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.125, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+                    layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.5, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
                     layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
                     layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
                 with gr.Row():
@@ -415,6 +425,8 @@ with ui:
                     layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
                 with gr.Row():
                     layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
+                    layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
+                    layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
 
     layout["inference_tts"]["buttons"]["inference"].click(
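Putting the new knobs together, a hypothetical call through the `TTS` wrapper: the keyword names for the new flags come from the `inference.py` signature in this patch, while the constructor arguments, text, and reference path are placeholder assumptions.

```python
from vall_e.inference import TTS

tts = TTS()  # assumption: a default model/config is already set up

# exit layers early once entropy/varentropy are low enough, then
# re-decode the final sequence from the last step's logits
wav, sr = tts.inference(
    text="The quick brown fox jumps over the lazy dog.",  # placeholder
    references=["./reference.wav"],                       # placeholder
    language="en",
    layer_skip=True,
    layer_skip_exit_layer=-1,            # -1: no fixed maximum layer
    layer_skip_entropy_threshold=0.1,
    layer_skip_varentropy_threshold=0.1,
    refine_on_stop=True,
)
```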