From bc30026377f14c3e5b6b428621bb32282a1776ba Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 9 Sep 2023 16:51:36 -0500 Subject: [PATCH] added advanced sampler parameters to the web UI --- vall_e/webui.py | 78 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 17 deletions(-) diff --git a/vall_e/webui.py b/vall_e/webui.py index b0bb4bd..29fc216 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -1,6 +1,7 @@ import os import re import argparse +import random import tempfile import functools @@ -56,14 +57,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument("--text", type=str, default=kwargs["text"]) parser.add_argument("--references", type=str, default=kwargs["reference"]) - parser.add_argument("--max-ar-steps", type=int, default=kwargs["steps"]) + parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75)) parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"]) - parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument("--top-k", type=int, default=0) - parser.add_argument("--repetition-penalty", type=float, default=1.0) - parser.add_argument("--repetition-penalty-decay", type=float, default=0.0) - parser.add_argument("--length-penalty", type=float, default=0.0) + parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) + parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) + parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"]) + parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) + parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"]) args, unknown = parser.parse_known_args() tmp = tempfile.NamedTemporaryFile(suffix='.wav') @@ -86,25 +87,68 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): wav = wav.squeeze(0).cpu().numpy() return (sr, wav) +def get_random_prompt(): + harvard_sentences=[ + "The birch canoe slid on the smooth planks.", + "Glue the sheet to the dark blue background.", + "It's easy to tell the depth of a well.", + "These days a chicken leg is a rare dish.", + "Rice is often served in round bowls.", + "The juice of lemons makes fine punch.", + "The box was thrown beside the parked truck.", + "The hogs were fed chopped corn and garbage.", + "Four hours of steady work faced us.", + "A large size in stockings is hard to sell.", + "The boy was there when the sun rose.", + "A rod is used to catch pink salmon.", + "The source of the huge river is the clear spring.", + "Kick the ball straight and follow through.", + "Help the woman get back to her feet.", + "A pot of tea helps to pass the evening.", + "Smoky fires lack flame and heat.", + "The soft cushion broke the man's fall.", + "The salt breeze came across from the sea.", + "The girl at the booth sold fifty bonds.", + "The small pup gnawed a hole in the sock.", + "The fish twisted and turned on the bent hook.", + "Press the pants and sew a button on the vest.", + "The swan dive was far short of perfect.", + "The beauty of the view stunned the young boy.", + "Two blue fish swam in the tank.", + "Her purse was full of useless trash.", + "The colt reared and threw the tall rider.", + "It snowed, rained, and hailed the same morning.", + "Read verse out loud for pleasure.", + ] + return random.choice(harvard_sentences) + ui = gr.Blocks() with ui: with gr.Tab("Inference"): with gr.Row(): - with gr.Column(): - layout["inference"]["inputs"]["text"] = gr.Textbox(lines=4, value="Your prompt here", label="Input Prompt") + with gr.Column(scale=8): + layout["inference"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt") with gr.Row(): - with gr.Column(): - layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath") - with gr.Column(): - layout["inference"]["inputs"]["steps"] = gr.Slider(value=450, minimum=2, maximum=1024, step=1, label="Steps") - layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)") - layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)") - with gr.Column(): - layout["inference"]["buttons"]["start"] = gr.Button(value="Inference") + with gr.Column(scale=1): + layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath", info="Reference audio for TTS") # layout["inference"]["stop"] = gr.Button(value="Stop") layout["inference"]["outputs"]["output"] = gr.Audio(label="Output") + layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference") + with gr.Column(scale=7): + layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.") + with gr.Row(): + layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.") + layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.") - layout["inference"]["buttons"]["start"].click( + with gr.Row(): + layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.") + layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.") + with gr.Row(): + layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") + layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.") + layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-40.0, maximum=4.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.") + + layout["inference"]["buttons"]["inference"].click( fn=do_inference, inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None], outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]