From ec792309657ac0785247ecb8efb30b65117483c0 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 1 Nov 2024 21:30:06 -0500
Subject: [PATCH] shuffled web UI options hidden by cfg.experimental to its own
 tab, expose early exit selection to inferencing (it kinda works naively, still
 need to implement self-speculation)

---
 vall_e/__main__.py          |  5 +++++
 vall_e/inference.py         |  7 +++++++
 vall_e/models/ar.py         |  5 ++++-
 vall_e/models/ar_nar.py     |  5 +++++
 vall_e/models/arch/llama.py |  5 ++++-
 vall_e/models/base.py       | 14 ++++++++++++--
 vall_e/models/nar.py        | 13 +++++++++++--
 vall_e/webui.py             | 30 ++++++++++++++++++++----------
 8 files changed, 68 insertions(+), 16 deletions(-)

diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 6308554..c6745c8 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -47,6 +47,9 @@ def main():
     parser.add_argument("--entropix-sampling", action="store_true")
 
+    parser.add_argument("--layer-skip", action="store_true")
+    parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
+
     parser.add_argument("--seed", type=int, default=None)
 
     parser.add_argument("--device", type=str, default=None)
@@ -81,6 +84,8 @@ def main():
         mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
         dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
         entropix_sampling=args.entropix_sampling,
+        layer_skip=args.layer_skip,
+        layer_skip_exit_layer=args.layer_skip_exit_layer,
         seed=args.seed,
     )
diff --git a/vall_e/inference.py b/vall_e/inference.py
index c9e7786..9365826 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -221,6 +221,9 @@ class TTS():
         #
         entropix_sampling=False,
         #
+        layer_skip=False,
+        layer_skip_exit_layer=-1,
+        #
         seed = None,
 
         out_path=None,
@@ -270,6 +273,8 @@ class TTS():
             sampling_dry_base=dry_base,
             sampling_dry_allowed_length=dry_allowed_length,
             sampling_entropix=entropix_sampling,
+            sampling_layer_skip=layer_skip,
+            sampling_layer_skip_exit_layer=layer_skip_exit_layer,
 
             disable_tqdm=not tqdm,
             use_lora=use_lora,
@@ -319,6 +324,8 @@ class TTS():
             sampling_dry_base=dry_base,
             sampling_dry_allowed_length=dry_allowed_length,
             sampling_entropix=entropix_sampling,
+            sampling_layer_skip=layer_skip,
+            sampling_layer_skip_exit_layer=layer_skip_exit_layer,
 
             disable_tqdm=not tqdm,
             use_lora=use_lora,
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index d355d15..e17763c 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -38,7 +38,7 @@ class AR(Base):
         tone_list: list[Tensor] | None = None,
         len_list: list[Tensor] | None = None,
 
-        training: bool | None = None,
+        training: bool | int | None = None,
 
         max_steps: int = 1000,
         max_levels: int = 0,
@@ -60,6 +60,9 @@ class AR(Base):
         sampling_dry_multiplier=0.0,
         sampling_dry_base=1.75,
         sampling_dry_allowed_length=2,
+        sampling_entropix=False,
+        sampling_layer_skip: bool = False,
+        sampling_layer_skip_exit_layer: int = -1,
 
         disable_tqdm=False,
         use_lora=None,
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 697743a..6563b83 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -66,6 +66,8 @@ class AR_NAR(Base):
         sampling_dry_base=1.75,
         sampling_dry_allowed_length=2,
         sampling_entropix=False,
+        sampling_layer_skip: bool = False,
+        sampling_layer_skip_exit_layer: int = -1,
 
         disable_tqdm=False,
         use_lora=None,
@@ -326,6 +328,9 @@ class AR_NAR(Base):
         output = super().forward(
             inputs=inputs,
             state=state,
+
+            layer_skip_exit_layer=sampling_layer_skip_exit_layer,
+
             output_attentions=sampling_entropix,
         )
         logits, state = output.logits, output.state
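[Note: the plumbing above exposes layer skipping end to end, from the new CLI flags in __main__.py through TTS in inference.py down to the samplers. A minimal usage sketch follows; TTS comes from this repo, but the inference() method name and the text/reference arguments are assumptions about the surrounding API, while layer_skip / layer_skip_exit_layer are the options this patch adds:

# Hedged usage sketch: only layer_skip / layer_skip_exit_layer are from this
# patch; the method name, prompt text, and reference clip are illustrative.
from vall_e.inference import TTS

tts = TTS()  # loads the model per the active config
wav, sr = tts.inference(
    text="Hello world.",              # hypothetical prompt text
    references=["./reference.wav"],   # hypothetical speaker reference clip
    layer_skip=True,                  # enable the naive early exit
    layer_skip_exit_layer=9,          # decoder layer index to exit from
    out_path="./output.wav",
)

Leaving layer_skip_exit_layer at its default of -1 keeps early exit disabled, per the guards shown below.]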
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 4efefa7..1550def 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -358,7 +358,7 @@ class LlamaModel_Adapted(LlamaModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        early_exit_layer: Optional[int] = -1,
+        exit_layer: Optional[int] = -1,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -451,6 +451,9 @@ class LlamaModel_Adapted(LlamaModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
+            if 0 <= exit_layer and exit_layer <= l:
+                break
+
         hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 031489b..489f2ba 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -826,6 +826,9 @@ class Base(nn.Module):
         position_ids = None,
 
         state = None,
+
+        layer_skip_exit_layer = -1,
+
         output_attentions = False,
         output_hidden_states = False,
     ):
@@ -848,9 +851,13 @@ class Base(nn.Module):
             output_hidden_states=output_hidden_states,
             return_dict=True,
         )
+
         if self.n_experts > 1 and self.training:
             kwargs["output_router_logits"] = True
 
+        if self.layerskip and 0 <= layer_skip_exit_layer and layer_skip_exit_layer < self.n_layers:
+            kwargs["exit_layer"] = layer_skip_exit_layer
+
         output = self.model(**kwargs)
         x = output["last_hidden_state"]
@@ -1436,8 +1443,10 @@ class Base(nn.Module):
         quant_levels: int | list[int] | Tensor | None = None,
         state: dict | list | None = None,
 
-        output_attentions = False,
-        output_hidden_states = False,
+        layer_skip_exit_layer: int = -1,
+
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
     ):
         x_list = self.inputs_to_embeddings( inputs, quant_levels )
         x, m = list_to_tensor(x_list)
@@ -1477,6 +1486,7 @@ class Base(nn.Module):
             position_ids=position_ids,
             output_attentions = output_attentions,
             output_hidden_states = output_hidden_states,
+            layer_skip_exit_layer = layer_skip_exit_layer,
         )
 
         logits = output.logits
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 89b0ebe..e2a4734 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -35,11 +35,13 @@ class NAR(Base):
         tone_list: list[Tensor] | None = None,
         len_list: list[Tensor] | None = None,
 
-        training: bool | None = None,
+        training: bool | int | None = None,
 
         max_steps: int = 1000,
         max_levels: int = 0,
-        max_resp_context: int = -1,
+
+        input_prompt_prefix: bool = False,
+        prefix_silence: float = 1.0,
 
         sampling_temperature: float = 1.0,
         sampling_min_temperature: float = -1.0,
@@ -52,8 +54,15 @@ class NAR(Base):
         sampling_beam_width: int = 0,
         sampling_mirostat_tau: float = 0.0,
         sampling_mirostat_eta: float = 0.1,
+        sampling_dry_multiplier=0.0,
+        sampling_dry_base=1.75,
+        sampling_dry_allowed_length=2,
+        sampling_entropix=False,
+        sampling_layer_skip: bool = False,
+        sampling_layer_skip_exit_layer: int = -1,
 
         disable_tqdm=False,
+        use_lora=None,
     ):
         device = text_list[0].device
         batch_size = len(text_list)
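[Note: the model-side mechanism above is small: Base.forward validates the requested layer against self.n_layers and forwards it as exit_layer, and LlamaModel_Adapted breaks out of its decoder loop once that layer index has run, still applying the final norm. A distilled, self-contained sketch of that control flow (a toy module, not the actual class):

# Toy module mirroring the early-exit control flow added above; the layer
# internals are simplified to keep the sketch runnable.
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self, n_layers: int = 12, d_model: int = 64):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(n_layers))
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, exit_layer: int = -1) -> torch.Tensor:
        for l, layer in enumerate(self.layers):
            x = layer(x)
            # Same guard as the patch: a negative exit_layer disables early
            # exit; otherwise break once layer index l reaches the requested layer.
            if 0 <= exit_layer and exit_layer <= l:
                break
        return self.norm(x)  # the final norm still applies to early-exited states

For example, TinyDecoder()(torch.randn(1, 8, 64), exit_layer=5) runs only the first six layers.]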
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 1c29f8a..b8de13c 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -192,6 +192,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
     parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
     parser.add_argument("--entropix-sampling", action="store_true")
+    parser.add_argument("--layer-skip", action="store_true")
+    parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
     args, unknown = parser.parse_known_args()
 
     tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -203,6 +205,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 
     if kwargs.pop("entropix-sampling", False):
         args.entropix_sampling = True
+
+    if kwargs.pop("layer-skip", False):
+        args.layer_skip = True
 
     tts = init_tts()
@@ -236,7 +241,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
             dry_multiplier=args.dry_multiplier,
             dry_base=args.dry_base,
             dry_allowed_length=args.dry_allowed_length,
-            entropix_sampling=args.entropix_sampling
+            entropix_sampling=args.entropix_sampling,
+            layer_skip=args.layer_skip,
+            layer_skip_exit_layer=args.layer_skip_exit_layer,
         )
 
         wav = wav.squeeze(0).cpu().numpy()
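[Note: the kwargs-to-argparse bridge in do_inference_tts has one subtlety worth spelling out: store_true flags cannot take their value from a kwargs default, so boolean checkbox values are popped from kwargs and applied to the parsed namespace afterwards. A self-contained sketch of that pattern (the kwargs dict stands in for values gathered from the gradio layout):

# Self-contained sketch of the checkbox-to-flag bridge used above; [] is
# passed to parse_known_args so no real argv is consumed.
import argparse

kwargs = {"layer-skip": True, "layer-skip-exit-layer": 11}  # e.g. values from the UI

parser = argparse.ArgumentParser()
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
args, unknown = parser.parse_known_args([])

# store_true flags parse to False when absent, so the checkbox value is
# popped and assigned onto the namespace afterwards, as in the patch:
if kwargs.pop("layer-skip", False):
    args.layer_skip = True

print(args.layer_skip, args.layer_skip_exit_layer)  # True 11]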
(0 to greedy sample)") with gr.Row(): - if cfg.experimental: - layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.") - layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.") layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") - if cfg.experimental: - layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.") layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en") with gr.Tab("Sampler Settings"): with gr.Row(): @@ -403,6 +403,18 @@ with ui: layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).") layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty") layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.") + if cfg.experimental: + with gr.Tab("Experimental Settings"): + with gr.Row(): + layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") + layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.") + with gr.Row(): + layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.") + with gr.Row(): + layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.") + with gr.Row(): + layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Model layer to exit early from.") + layout["inference_tts"]["buttons"]["inference"].click( fn=do_inference_tts, @@ -425,8 +437,6 @@ with ui: layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. 
(0 to greedy sample)") with gr.Row(): layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") - if cfg.experimental: - layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.") layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en") with gr.Tab("Sampler Settings"): with gr.Row():