added advanced sampler parameters to the web UI
This commit is contained in:
parent
5ac119a6e7
commit
bc30026377
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import re
|
||||
import argparse
|
||||
import random
|
||||
import tempfile
|
||||
import functools
|
||||
|
||||
|
@ -56,14 +57,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--text", type=str, default=kwargs["text"])
|
||||
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||
parser.add_argument("--max-ar-steps", type=int, default=kwargs["steps"])
|
||||
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
|
||||
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
|
||||
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
|
||||
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
|
||||
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
|
||||
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
||||
|
@ -86,25 +87,68 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
wav = wav.squeeze(0).cpu().numpy()
|
||||
return (sr, wav)
|
||||
|
||||
def get_random_prompt():
|
||||
harvard_sentences=[
|
||||
"The birch canoe slid on the smooth planks.",
|
||||
"Glue the sheet to the dark blue background.",
|
||||
"It's easy to tell the depth of a well.",
|
||||
"These days a chicken leg is a rare dish.",
|
||||
"Rice is often served in round bowls.",
|
||||
"The juice of lemons makes fine punch.",
|
||||
"The box was thrown beside the parked truck.",
|
||||
"The hogs were fed chopped corn and garbage.",
|
||||
"Four hours of steady work faced us.",
|
||||
"A large size in stockings is hard to sell.",
|
||||
"The boy was there when the sun rose.",
|
||||
"A rod is used to catch pink salmon.",
|
||||
"The source of the huge river is the clear spring.",
|
||||
"Kick the ball straight and follow through.",
|
||||
"Help the woman get back to her feet.",
|
||||
"A pot of tea helps to pass the evening.",
|
||||
"Smoky fires lack flame and heat.",
|
||||
"The soft cushion broke the man's fall.",
|
||||
"The salt breeze came across from the sea.",
|
||||
"The girl at the booth sold fifty bonds.",
|
||||
"The small pup gnawed a hole in the sock.",
|
||||
"The fish twisted and turned on the bent hook.",
|
||||
"Press the pants and sew a button on the vest.",
|
||||
"The swan dive was far short of perfect.",
|
||||
"The beauty of the view stunned the young boy.",
|
||||
"Two blue fish swam in the tank.",
|
||||
"Her purse was full of useless trash.",
|
||||
"The colt reared and threw the tall rider.",
|
||||
"It snowed, rained, and hailed the same morning.",
|
||||
"Read verse out loud for pleasure.",
|
||||
]
|
||||
return random.choice(harvard_sentences)
|
||||
|
||||
ui = gr.Blocks()
|
||||
with ui:
|
||||
with gr.Tab("Inference"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
layout["inference"]["inputs"]["text"] = gr.Textbox(lines=4, value="Your prompt here", label="Input Prompt")
|
||||
with gr.Column(scale=8):
|
||||
layout["inference"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath")
|
||||
with gr.Column():
|
||||
layout["inference"]["inputs"]["steps"] = gr.Slider(value=450, minimum=2, maximum=1024, step=1, label="Steps")
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)")
|
||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)")
|
||||
with gr.Column():
|
||||
layout["inference"]["buttons"]["start"] = gr.Button(value="Inference")
|
||||
with gr.Column(scale=1):
|
||||
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath", info="Reference audio for TTS")
|
||||
# layout["inference"]["stop"] = gr.Button(value="Stop")
|
||||
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
|
||||
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||
with gr.Column(scale=7):
|
||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
||||
|
||||
layout["inference"]["buttons"]["start"].click(
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
|
||||
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-40.0, maximum=4.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||
|
||||
layout["inference"]["buttons"]["inference"].click(
|
||||
fn=do_inference,
|
||||
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
|
||||
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
|
||||
|
|
Loading…
Reference in New Issue
Block a user