From 71731ed785befcccfbc631814ed6897a3590672b Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 18 Oct 2024 17:19:52 -0500 Subject: [PATCH] added prefixing with silence (was to test something, currently hidden under cfg.experimental=True) --- vall_e/inference.py | 2 ++ vall_e/models/ar.py | 3 +++ vall_e/models/ar_nar.py | 8 ++++++-- vall_e/webui.py | 14 +++++++++----- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/vall_e/inference.py b/vall_e/inference.py index b5817db..6463f18 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -184,6 +184,7 @@ class TTS(): # input_prompt_length=0.0, input_prompt_prefix=False, + prefix_silence=0.0, # ar_temp=0.0, nar_temp=0.0, @@ -295,6 +296,7 @@ class TTS(): resps_list = model_ar( text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, input_prompt_prefix=input_prompt_prefix, + prefix_silence=prefix_silence, sampling_temperature=ar_temp, sampling_min_temperature=min_ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p, diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index 95bd435..d355d15 100644 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -43,6 +43,9 @@ class AR(Base): max_steps: int = 1000, max_levels: int = 0, + input_prompt_prefix: bool = False, + prefix_silence: float = 1.0, + sampling_temperature: float = 1.0, sampling_min_temperature: float = -1.0, sampling_top_k: int = -100, diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 8454f53..d7b54bd 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -23,7 +23,7 @@ import logging _logger = logging.getLogger(__name__) -from ..emb.qnt import trim, encode_as_embedding +from ..emb.qnt import trim, encode_as_embedding, get_silence from ..utils import get_devices, setup_logging, timer from .lora import enable_lora @@ -49,6 +49,7 @@ class AR_NAR(Base): max_levels: int = 0, input_prompt_prefix: bool = False, + prefix_silence: float = 1.0, sampling_temperature: float = 1.0, sampling_min_temperature: float = -1.0, @@ -284,6 +285,10 @@ class AR_NAR(Base): elif input_prompt_prefix: start_slice[i] = proms_list[i].shape[0] sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i] + elif prefix_silence > 0: + sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device) + sequence_list[i] = sequence_list[i][:, 0] + # start_slice[i] = sequence_list[i].shape[0] # get next in sequence for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm): @@ -295,7 +300,6 @@ class AR_NAR(Base): # naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio # however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled # to-do: tune these values, maybe have it factor based on confidence scores or something - # to-do: see if instead just prefixing with blank audio overcomes the initla noise anyways if low_temperature: enabled = n < low_temperature_range sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty diff --git a/vall_e/webui.py b/vall_e/webui.py index b49716a..2a9a33f 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -147,14 +147,14 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--references", type=str, default=kwargs["reference"]) parser.add_argument("--language", type=str, default=kwargs["language"]) parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"]) - #parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"]) - parser.add_argument("--input-prompt-prefix", action='store_true') + parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"] if cfg.experimental else False) parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second)) - parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"]) + parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"] if cfg.experimental else 0) parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"]) parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"]) parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"]) + parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"] if cfg.experimental else 0) parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) parser.add_argument("--min-p", type=float, default=kwargs["min-p"]) @@ -195,6 +195,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): max_nar_levels=args.max_nar_levels, input_prompt_length=args.input_prompt_length, input_prompt_prefix=args.input_prompt_prefix, + prefix_silence=args.prefix_silence, ar_temp=args.ar_temp, nar_temp=args.nar_temp, min_ar_temp=args.min_ar_temp, @@ -345,13 +346,16 @@ with ui: with gr.Tab("Basic Settings"): with gr.Row(): layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.") - #layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") + if cfg.experimental: + layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.") with gr.Row(): layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)") layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)") with gr.Row(): - #layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.") + if cfg.experimental: + layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.") + layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.") layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") if cfg.experimental: layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")