From e094bd23ecbed657349cf4b8d83595b57918eef3 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 5 Nov 2024 21:22:29 -0600 Subject: [PATCH] web UI support for HF ZeroGPU --- vall_e/webui.py | 71 +++++++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/vall_e/webui.py b/vall_e/webui.py index 2a1f4c8..07fab24 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -6,6 +6,7 @@ import argparse import random import tempfile import functools +import spaces import torch import numpy as np @@ -21,7 +22,17 @@ from .utils import get_devices, setup_logging, timer from .utils.io import json_read, json_stringify from .emb.qnt import decode_to_wave from .data import get_lang_symmap, get_random_prompt +from .models.arch import AVAILABLE_ATTENTIONS +try: + import spaces + + USING_SPACES = True + spaces_zerogpu_decorator = spaces.GPU +except ImportError: + USING_SPACES = False + def spaces_zerogpu_decorator(func): + return func is_windows = sys.platform.startswith("win") @@ -79,7 +90,6 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data def get_dtypes(): return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"] -from .models.arch import AVAILABLE_ATTENTIONS def get_attentions(): return AVAILABLE_ATTENTIONS + ["auto"] @@ -158,6 +168,7 @@ def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention ) return tts +@spaces_zerogpu_decorator @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys()) def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): if not cfg.models: @@ -501,37 +512,39 @@ with ui: ) """ - with gr.Tab("Dataset"): - with gr.Row(): - with gr.Column(scale=7): - layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata") - with gr.Column(scale=1): - layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers") - layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output") - layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample") + if not USING_SPACES: + with gr.Tab("Dataset"): + with gr.Row(): + with gr.Column(scale=7): + layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata") + with gr.Column(scale=1): + layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers") + layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output") + layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample") - layout["dataset"]["buttons"]["sample"].click( - fn=load_sample, - inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None], - outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None], - ) + layout["dataset"]["buttons"]["sample"].click( + fn=load_sample, + inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None], + outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None], + ) - with gr.Tab("Settings"): - with gr.Row(): - with gr.Column(scale=7): - with gr.Row(): - layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model") - layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device") - layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision") - layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions") - with gr.Column(scale=1): - layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model") + if not USING_SPACES: + with gr.Tab("Settings"): + with gr.Row(): + with gr.Column(scale=7): + with gr.Row(): + layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model") + layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device") + layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision") + layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions") + with gr.Column(scale=1): + layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model") - layout["settings"]["buttons"]["load"].click( - fn=load_model, - inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None], - outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None], - ) + layout["settings"]["buttons"]["load"].click( + fn=load_model, + inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None], + outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None], + ) if os.path.exists("README.md") and args.render_markdown: md = open("README.md", "r", encoding="utf-8").read()