diff --git a/vall_e/__main__.py b/vall_e/__main__.py index f1fcc8a..297a256 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -33,6 +33,7 @@ def main(): parser.add_argument("--input-prompt-prefix", action="store_true") parser.add_argument("--prefix-silence", type=float, default=0.0) parser.add_argument("--cfg-strength", type=float, default=0.0) + parser.add_argument("--cfg-rescale", type=float, default=0.75) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=0) @@ -102,6 +103,7 @@ def main(): input_prompt_prefix=args.input_prompt_prefix, prefix_silence=args.prefix_silence, cfg_strength=args.cfg_strength, + cfg_rescale=args.cfg_rescale, ) output = tts.inference( diff --git a/vall_e/demo.py b/vall_e/demo.py index 041af36..9c6e9f9 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -71,7 +71,8 @@ def main(): parser.add_argument("--input-prompt-length", type=float, default=3.0) parser.add_argument("--input-prompt-prefix", action="store_true") parser.add_argument("--prefix-silence", type=float, default=0.0) - parser.add_argument("--cfg-strength", type=float, default=3.0) + parser.add_argument("--cfg-strength", type=float, default=1.0) + parser.add_argument("--cfg-rescale", type=float, default=0.75) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=0) @@ -255,6 +256,7 @@ def main(): input_prompt_prefix=args.input_prompt_prefix, prefix_silence=args.prefix_silence, cfg_strength=args.cfg_strength, + cfg_rescale=args.cfg_rescale, ) # replace values in our template diff --git a/vall_e/inference.py b/vall_e/inference.py index d10bf64..fbc220f 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -196,7 +196,6 @@ class TTS(): input_prompt_length = 0, load_from_artifact = False, - nar_len_prefix_length = 0, seed = None, out_path=None, @@ -272,7 +271,15 @@ class TTS(): # to-do: add in case for experimental.hf model with torch.autocast("cuda", 
dtype=self.dtype, enabled=self.amp): if model_len is not None: + # extra kwargs + duration_padding = sampling_kwargs.pop("duration_padding", 1) + nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0) + len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # don't need more than that + + # add an additional X seconds + len_list = [ l + duration_padding * cfg.dataset.frames_per_second for l in len_list ] + kwargs = {} # nasty hardcode to load a reference file and have that as the input target if load_from_artifact and load_from_artifact.exists(): @@ -280,10 +287,9 @@ class TTS(): phns = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=self.device) resp = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=self.device) - prom = resp[:75*3, :] len_list = [ resp.shape[0] ] - kwargs["resps_list"] = [ resp[:, :1] ] + kwargs["resps_list"] = [ resp[:, 0] ] # kludge experiment elif nar_len_prefix_length > 0: resps_list = model_nar( diff --git a/vall_e/webui.py b/vall_e/webui.py index 9772760..6fb666b 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -236,6 +236,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--refine-on-stop", action="store_true") parser.add_argument("--denoise-start", type=float, default=0.0) parser.add_argument("--cfg-strength", type=float, default=kwargs['cfg-strength']) + parser.add_argument("--cfg-rescale", type=float, default=kwargs['cfg-rescale']) args, unknown = parser.parse_known_args() if is_windows: @@ -284,6 +285,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): input_prompt_prefix=args.input_prompt_prefix, input_prompt_length=args.input_prompt_length, cfg_strength=args.cfg_strength, + cfg_rescale=args.cfg_rescale, ) with timer("Inferenced in", callback=lambda msg: 
gr.Info( msg )) as t: @@ -425,21 +427,23 @@ with ui: with gr.Column(scale=7): with gr.Tab("Basic Settings"): with gr.Row(): - layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.") + layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Duration", info="Limits how many steps to perform in the AR pass.") + layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=50, minimum=1, maximum=200, step=1, label="Max Steps (NAR-len)", info="Limits how many steps to perform in the NAR-len (demask) pass.") layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.") with gr.Row(): - layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)") + layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Modifies the randomness from the samples in the AR/NAR-len. (0 to greedy* sample)") layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. 
(0 to greedy sample)") - with gr.Row(): - layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale") - layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en") layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.") + with gr.Row(): + layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).") + layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).") + layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en") with gr.Tab("Sampler Settings"): with gr.Row(): layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.") layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.") layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=0, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.") - layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P") + layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P", info="Filter out logits lower than this value.") with gr.Row(): 
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the repetition penalty based on how far back in time the token appeared in the sequence.") @@ -453,19 +457,16 @@ with ui: layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximum length a token can be to perform DRY penalty with.") with gr.Tab("Experimental Settings", visible=cfg.experimental): with gr.Row(): - layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.") layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") - with gr.Row(): - layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.") + layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.") layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.") with gr.Row(): - layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam 
search sampling.") + layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.") layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.") - with gr.Row(): - layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'") layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.") - with gr.Row(): + with gr.Row(visible=False): + layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'") layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.") layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit") layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")