vall-e/vall_e/webui.py

339 lines
15 KiB
Python

import os
import re
import argparse
import random
import tempfile
import functools
from datetime import datetime
import gradio as gr
from time import perf_counter
from pathlib import Path
from .inference import TTS, cfg
from .train import train
from .utils import get_devices
tts = None
layout = {}
layout["inference"] = {}
layout["training"] = {}
layout["settings"] = {}
for k in layout.keys():
layout[k]["inputs"] = { "progress": None }
layout[k]["outputs"] = {}
layout[k]["buttons"] = {}
# there's got to be a better way to go about this
def gradio_wrapper(inputs):
def decorated(fun):
@functools.wraps(fun)
def wrapped_function(*args, **kwargs):
for i, key in enumerate(inputs):
kwargs[key] = args[i]
try:
return fun(**kwargs)
except Exception as e:
raise gr.Error(str(e))
return wrapped_function
return decorated
class timer:
def __init__(self, msg="Elapsed time:"):
self.msg = msg
def __enter__(self):
self.start = perf_counter()
return self
def __exit__(self, type, value, traceback):
msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
gr.Info(msg)
print(f'[{datetime.now().isoformat()}] {msg}')
# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
yamls = []
for path in paths:
if not path.exists():
continue
for yaml in path.glob("**/*.yaml"):
if "/logs/" in str(yaml):
continue
yamls.append( yaml )
return yamls
def get_dtypes():
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype ):
gr.Info(f"Loading: {yaml}")
try:
init_tts( yaml=Path(yaml), restart=True )
except Exception as e:
raise gr.Error(e)
gr.Info(f"Loaded model")
def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
global tts
if tts is not None:
if not restart:
return tts
del tts
tts = None
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--device", type=str, default=device)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=dtype)
args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return tts
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if not cfg.yaml_path:
raise Exception("No YAML loaded.")
if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
else:
kwargs['min-ar-temp'] = -1
kwargs['min-nar-temp'] = -1
parser = argparse.ArgumentParser(allow_abbrev=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
"""
if not args.references:
raise Exception("No reference audio provided.")
"""
tts = init_tts()
gr.Info("Inferencing...")
with timer("Inferenced in") as t:
wav, sr = tts.inference(
text=args.text,
language=args.language,
references=[args.references.split(";")] if args.references is not None else [],
out_path=tmp.name,
max_ar_steps=args.max_ar_steps,
max_nar_levels=args.max_nar_levels,
input_prompt_length=args.input_prompt_length,
ar_temp=args.ar_temp,
nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp,
min_nar_temp=args.min_nar_temp,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
mirostat_tau=args.mirostat_tau,
mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier,
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
)
wav = wav.squeeze(0).cpu().numpy()
return (sr, wav)
"""
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
while True:
metrics = next(it)
yield metrics
"""
def get_random_prompt():
harvard_sentences=[
"The birch canoe slid on the smooth planks.",
"Glue the sheet to the dark blue background.",
"It's easy to tell the depth of a well.",
"These days a chicken leg is a rare dish.",
"Rice is often served in round bowls.",
"The juice of lemons makes fine punch.",
"The box was thrown beside the parked truck.",
"The hogs were fed chopped corn and garbage.",
"Four hours of steady work faced us.",
"A large size in stockings is hard to sell.",
"The boy was there when the sun rose.",
"A rod is used to catch pink salmon.",
"The source of the huge river is the clear spring.",
"Kick the ball straight and follow through.",
"Help the woman get back to her feet.",
"A pot of tea helps to pass the evening.",
"Smoky fires lack flame and heat.",
"The soft cushion broke the man's fall.",
"The salt breeze came across from the sea.",
"The girl at the booth sold fifty bonds.",
"The small pup gnawed a hole in the sock.",
"The fish twisted and turned on the bent hook.",
"Press the pants and sew a button on the vest.",
"The swan dive was far short of perfect.",
"The beauty of the view stunned the young boy.",
"Two blue fish swam in the tank.",
"Her purse was full of useless trash.",
"The colt reared and threw the tall rider.",
"It snowed, rained, and hailed the same morning.",
"Read verse out loud for pleasure.",
]
return random.choice(harvard_sentences)
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
args, unknown = parser.parse_known_args()
args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
try:
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
args.listen_port = match[1] if match[1] != "" else None
args.listen_path = match[2] if match[2] != "" else "/"
except Exception as e:
pass
if args.listen_port is not None:
args.listen_port = int(args.listen_port)
if args.listen_port == 0:
args.listen_port = None
# setup gradio
ui = gr.Blocks()
with ui:
with gr.Tab("Inference"):
with gr.Row():
with gr.Column(scale=8):
layout["inference"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
with gr.Row():
with gr.Column(scale=1):
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
# layout["inference"]["stop"] = gr.Button(value="Stop")
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
with gr.Column(scale=7):
with gr.Row():
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
#layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
with gr.Row():
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
layout["inference"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row():
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
with gr.Row():
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
layout["inference"]["buttons"]["inference"].click(
fn=do_inference,
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
)
"""
with gr.Tab("Training"):
with gr.Row():
with gr.Column(scale=1):
layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
with gr.Row():
with gr.Column(scale=1):
layout["training"]["buttons"]["train"] = gr.Button(value="Train")
layout["training"]["buttons"]["train"].click(
fn=do_training,
outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
)
"""
with gr.Tab("Settings"):
with gr.Row():
with gr.Column(scale=7):
with gr.Row():
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda", label="Device")
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
with gr.Column(scale=1):
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
layout["settings"]["buttons"]["load"].click(
fn=load_model,
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
)
if os.path.exists("README.md") and args.render_markdown:
md = open("README.md", "r", encoding="utf-8").read()
# remove HF's metadata
if md.startswith("---\n"):
md = "".join(md.split("---")[2:])
gr.Markdown(md)
def start( lock=True ):
ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
if __name__ == "__main__":
start()