From 7c9144ff22b4139de4f701ca8506996b5bffcb17 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 18 Jun 2024 21:03:25 -0500 Subject: [PATCH] working webui --- tortoise_tts/webui.py | 45 ++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/tortoise_tts/webui.py b/tortoise_tts/webui.py index 5014851..33984db 100644 --- a/tortoise_tts/webui.py +++ b/tortoise_tts/webui.py @@ -64,34 +64,43 @@ def init_tts(restart=False): @gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): + """ if kwargs.pop("dynamic-sampling", False): kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0 kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR else: kwargs['min-ar-temp'] = -1 kwargs['min-nar-temp'] = -1 + """ parser = argparse.ArgumentParser(allow_abbrev=False) # I'm very sure I can procedurally generate this list parser.add_argument("--text", type=str, default=kwargs["text"]) parser.add_argument("--references", type=str, default=kwargs["reference"]) + parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-ar-steps"])) + parser.add_argument("--max-diffusion-steps", type=int, default=int(kwargs["max-diffusion-steps"])) + """ parser.add_argument("--language", type=str, default="en") parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"]) - parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second)) parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*cfg.dataset.frames_per_second)) parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"]) + """ parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) - parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"]) + parser.add_argument("--diffusion-temp", type=float, default=kwargs["diffusion-temp"]) + """ parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"]) - parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"]) + parser.add_argument("--min-diffusion-temp", type=float, default=kwargs["min-diffusion-temp"]) + """ parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"]) - parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"]) parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) + """ + parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"]) parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"]) + """ args, unknown = parser.parse_known_args() tmp = tempfile.NamedTemporaryFile(suffix='.wav') @@ -103,23 +112,19 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): with timer() as t: wav, sr = tts.inference( text=args.text, - language=args.language, - references=[args.references.split(";")], + #language=args.language, + references=[args.references], out_path=tmp.name, max_ar_steps=args.max_ar_steps, - max_nar_levels=args.max_nar_levels, - input_prompt_length=args.input_prompt_length, + max_diffusion_steps=args.max_diffusion_steps, ar_temp=args.ar_temp, - nar_temp=args.nar_temp, - min_ar_temp=args.min_ar_temp, - min_nar_temp=args.min_nar_temp, + diffusion_temp=args.diffusion_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, - repetition_penalty_decay=args.repetition_penalty_decay, + #repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, - mirostat_tau=args.mirostat_tau, - mirostat_eta=args.mirostat_eta, + beam_width=args.beam_width, ) wav = wav.squeeze(0).cpu().numpy() @@ -207,16 +212,23 @@ with ui: layout["inference"]["outputs"]["output"] = gr.Audio(label="Output") layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference") with gr.Column(scale=7): + """ with gr.Row(): layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.") layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.") layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.") + """ + with gr.Row(): + layout["inference"]["inputs"]["max-ar-steps"] = gr.Slider(value=500, minimum=16, maximum=1200, step=1, label="Maximum AR Steps", info="Limits how many steps to perform in the AR pass.") + layout["inference"]["inputs"]["max-diffusion-steps"] = gr.Slider(value=80, minimum=16, maximum=500, step=1, label="Maximum Diffusion Steps", info="Limits how many steps to perform in the Diffusion pass.") with gr.Row(): layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)") - layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)") + layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)") + """ with gr.Row(): layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") + """ with gr.Row(): layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.") @@ -226,10 +238,11 @@ with ui: layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.") layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.") + """ with gr.Row(): layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.") layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.") - + """ layout["inference"]["buttons"]["inference"].click( fn=do_inference, inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],