web UI support for HF ZeroGPU
parent e58a9469a3
commit e094bd23ec
@@ -6,6 +6,7 @@ import argparse
 import random
 import tempfile
 import functools
+import spaces

 import torch
 import numpy as np
@@ -21,7 +22,17 @@ from .utils import get_devices, setup_logging, timer
 from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
 from .data import get_lang_symmap, get_random_prompt
+from .models.arch import AVAILABLE_ATTENTIONS

+try:
+	import spaces
+
+	USING_SPACES = True
+	spaces_zerogpu_decorator = spaces.GPU
+except ImportError:
+	USING_SPACES = False
+	def spaces_zerogpu_decorator(func):
+		return func

 is_windows = sys.platform.startswith("win")

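For context, the block added above is a standard optional-dependency pattern: on a Hugging Face ZeroGPU Space the `spaces` package is importable and `spaces.GPU` becomes the decorator that attaches a GPU for the duration of each decorated call, while everywhere else the import fails and a no-op stand-in is substituted so the rest of the module runs unchanged. A minimal, self-contained sketch of the same pattern (the `generate` function is a hypothetical stand-in, not part of this commit):

# Sketch of the optional ZeroGPU decorator pattern introduced above.
try:
	import spaces					# only available on Hugging Face Spaces
	USING_SPACES = True
	spaces_zerogpu_decorator = spaces.GPU		# requests a GPU around each decorated call
except ImportError:
	USING_SPACES = False
	def spaces_zerogpu_decorator(func):		# no-op fallback for local runs
		return func

@spaces_zerogpu_decorator
def generate(text):
	# Hypothetical GPU-bound work; runs as a plain function when the decorator is a no-op.
	return text.upper()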
@@ -79,7 +90,6 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
 def get_dtypes():
 	return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]

-from .models.arch import AVAILABLE_ATTENTIONS
 def get_attentions():
 	return AVAILABLE_ATTENTIONS + ["auto"]

@@ -158,6 +168,7 @@ def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto",
 	tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
 	return tts

+@spaces_zerogpu_decorator
 @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
 def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	if not cfg.models:
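Note the ordering in the hunk above: `@spaces_zerogpu_decorator` is listed first, so it is applied last and becomes the outermost wrapper around the gradio-wrapped handler, meaning the GPU request spans the whole inference call. Written without decorator syntax, the stacking is equivalent to the sketch below (the function body is elided):

# Equivalent to the stacked decorators in the hunk above.
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
	...

do_inference_tts = gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())(do_inference_tts)
do_inference_tts = spaces_zerogpu_decorator(do_inference_tts)	# outermost: ZeroGPU wraps the whole call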
@@ -501,37 +512,39 @@ with ui:
 	)
 	"""

-	with gr.Tab("Dataset"):
-		with gr.Row():
-			with gr.Column(scale=7):
-				layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata")
-			with gr.Column(scale=1):
-				layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers")
-				layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output")
-				layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample")
+	if not USING_SPACES:
+		with gr.Tab("Dataset"):
+			with gr.Row():
+				with gr.Column(scale=7):
+					layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata")
+				with gr.Column(scale=1):
+					layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers")
+					layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output")
+					layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample")

-	layout["dataset"]["buttons"]["sample"].click(
-		fn=load_sample,
-		inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None],
-		outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
-	)
+		layout["dataset"]["buttons"]["sample"].click(
+			fn=load_sample,
+			inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None],
+			outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
+		)

-	with gr.Tab("Settings"):
-		with gr.Row():
-			with gr.Column(scale=7):
-				with gr.Row():
-					layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
-					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
-					layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
-					layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
-			with gr.Column(scale=1):
-				layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
+	if not USING_SPACES:
+		with gr.Tab("Settings"):
+			with gr.Row():
+				with gr.Column(scale=7):
+					with gr.Row():
+						layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
+						layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
+						layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
+						layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
+				with gr.Column(scale=1):
+					layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")

-	layout["settings"]["buttons"]["load"].click(
-		fn=load_model,
-		inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
-		outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
-	)
+		layout["settings"]["buttons"]["load"].click(
+			fn=load_model,
+			inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
+			outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
+		)

 	if os.path.exists("README.md") and args.render_markdown:
 		md = open("README.md", "r", encoding="utf-8").read()
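The last hunk does not change any widget definitions; it only wraps the Dataset and Settings tabs, together with their button wiring, in `if not USING_SPACES:`, so those tabs are simply not built when the UI runs on a Space, presumably because they depend on local files and model paths. A minimal sketch of the same gating pattern, assuming a `USING_SPACES` flag like the one defined earlier (names here are illustrative, not from the commit):

# Sketch: build local-only tabs only when not running as a Hugging Face Space.
import gradio as gr

USING_SPACES = False	# True when `import spaces` succeeded

with gr.Blocks() as demo:
	with gr.Tab("Inference"):
		gr.Markdown("Available both locally and on Spaces")
	if not USING_SPACES:
		# Tabs that rely on local datasets or model files are skipped on a Space.
		with gr.Tab("Dataset"):
			gr.Markdown("Local-only tooling")

demo.launch()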