added advanced sampler parameters to the web UI

This commit is contained in:
mrq 2023-09-09 16:51:36 -05:00
parent 5ac119a6e7
commit bc30026377

View File

@ -1,6 +1,7 @@
import os
import re
import argparse
import random
import tempfile
import functools
@ -56,14 +57,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--max-ar-steps", type=int, default=kwargs["steps"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -86,25 +87,68 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
wav = wav.squeeze(0).cpu().numpy()
return (sr, wav)
def get_random_prompt():
harvard_sentences=[
"The birch canoe slid on the smooth planks.",
"Glue the sheet to the dark blue background.",
"It's easy to tell the depth of a well.",
"These days a chicken leg is a rare dish.",
"Rice is often served in round bowls.",
"The juice of lemons makes fine punch.",
"The box was thrown beside the parked truck.",
"The hogs were fed chopped corn and garbage.",
"Four hours of steady work faced us.",
"A large size in stockings is hard to sell.",
"The boy was there when the sun rose.",
"A rod is used to catch pink salmon.",
"The source of the huge river is the clear spring.",
"Kick the ball straight and follow through.",
"Help the woman get back to her feet.",
"A pot of tea helps to pass the evening.",
"Smoky fires lack flame and heat.",
"The soft cushion broke the man's fall.",
"The salt breeze came across from the sea.",
"The girl at the booth sold fifty bonds.",
"The small pup gnawed a hole in the sock.",
"The fish twisted and turned on the bent hook.",
"Press the pants and sew a button on the vest.",
"The swan dive was far short of perfect.",
"The beauty of the view stunned the young boy.",
"Two blue fish swam in the tank.",
"Her purse was full of useless trash.",
"The colt reared and threw the tall rider.",
"It snowed, rained, and hailed the same morning.",
"Read verse out loud for pleasure.",
]
return random.choice(harvard_sentences)
ui = gr.Blocks()
with ui:
with gr.Tab("Inference"):
with gr.Row():
with gr.Column():
layout["inference"]["inputs"]["text"] = gr.Textbox(lines=4, value="Your prompt here", label="Input Prompt")
with gr.Column(scale=8):
layout["inference"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
with gr.Row():
with gr.Column():
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath")
with gr.Column():
layout["inference"]["inputs"]["steps"] = gr.Slider(value=450, minimum=2, maximum=1024, step=1, label="Steps")
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)")
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)")
with gr.Column():
layout["inference"]["buttons"]["start"] = gr.Button(value="Inference")
with gr.Column(scale=1):
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath", info="Reference audio for TTS")
# layout["inference"]["stop"] = gr.Button(value="Stop")
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
with gr.Column(scale=7):
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.")
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
layout["inference"]["buttons"]["start"].click(
with gr.Row():
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
with gr.Row():
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-40.0, maximum=4.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
layout["inference"]["buttons"]["inference"].click(
fn=do_inference,
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]