web UI support for HF ZeroGPU

This commit is contained in:
mrq 2024-11-05 21:22:29 -06:00
parent e58a9469a3
commit e094bd23ec

View File

@ -6,6 +6,7 @@ import argparse
import random import random
import tempfile import tempfile
import functools import functools
import spaces
import torch import torch
import numpy as np import numpy as np
@ -21,7 +22,17 @@ from .utils import get_devices, setup_logging, timer
from .utils.io import json_read, json_stringify from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave from .emb.qnt import decode_to_wave
from .data import get_lang_symmap, get_random_prompt from .data import get_lang_symmap, get_random_prompt
from .models.arch import AVAILABLE_ATTENTIONS
# Optional Hugging Face ZeroGPU integration: when the `spaces` package is
# available (i.e. running on an HF Space), expose its GPU decorator; otherwise
# fall back to a no-op so decorated functions work unchanged locally.
try:
    import spaces
except ImportError:
    USING_SPACES = False
    def spaces_zerogpu_decorator(fn):
        # No-op stand-in for spaces.GPU outside of HF Spaces.
        return fn
else:
    USING_SPACES = True
    spaces_zerogpu_decorator = spaces.GPU
is_windows = sys.platform.startswith("win") is_windows = sys.platform.startswith("win")
@ -79,7 +90,6 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
def get_dtypes():
    """Return the precision choices offered in the web UI dropdown."""
    dtypes = ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn"]
    dtypes.append("auto")
    return dtypes
from .models.arch import AVAILABLE_ATTENTIONS
def get_attentions():
    """Return the attention-backend choices for the UI: everything the arch
    layer reports as available, plus "auto" for default selection."""
    return [*AVAILABLE_ATTENTIONS, "auto"]
@ -158,6 +168,7 @@ def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto",
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention ) tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
return tts return tts
@spaces_zerogpu_decorator
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys()) @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if not cfg.models: if not cfg.models:
@ -501,37 +512,39 @@ with ui:
) )
""" """
with gr.Tab("Dataset"): if not USING_SPACES:
with gr.Row(): with gr.Tab("Dataset"):
with gr.Column(scale=7): with gr.Row():
layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata") with gr.Column(scale=7):
with gr.Column(scale=1): layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata")
layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers") with gr.Column(scale=1):
layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output") layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers")
layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample") layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output")
layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample")
layout["dataset"]["buttons"]["sample"].click( layout["dataset"]["buttons"]["sample"].click(
fn=load_sample, fn=load_sample,
inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None], inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None], outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
) )
with gr.Tab("Settings"): if not USING_SPACES:
with gr.Row(): with gr.Tab("Settings"):
with gr.Column(scale=7): with gr.Row():
with gr.Row(): with gr.Column(scale=7):
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model") with gr.Row():
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device") layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision") layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions") layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
with gr.Column(scale=1): layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model") with gr.Column(scale=1):
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
layout["settings"]["buttons"]["load"].click( layout["settings"]["buttons"]["load"].click(
fn=load_model, fn=load_model,
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None], inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None], outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
) )
if os.path.exists("README.md") and args.render_markdown: if os.path.exists("README.md") and args.render_markdown:
md = open("README.md", "r", encoding="utf-8").read() md = open("README.md", "r", encoding="utf-8").read()