web UI support for HF ZeroGPU
This commit is contained in:
parent
e58a9469a3
commit
e094bd23ec
|
@ -6,6 +6,7 @@ import argparse
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import functools
|
import functools
|
||||||
|
import spaces
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -21,7 +22,17 @@ from .utils import get_devices, setup_logging, timer
|
||||||
from .utils.io import json_read, json_stringify
|
from .utils.io import json_read, json_stringify
|
||||||
from .emb.qnt import decode_to_wave
|
from .emb.qnt import decode_to_wave
|
||||||
from .data import get_lang_symmap, get_random_prompt
|
from .data import get_lang_symmap, get_random_prompt
|
||||||
|
from .models.arch import AVAILABLE_ATTENTIONS
|
||||||
|
|
||||||
|
try:
|
||||||
|
import spaces
|
||||||
|
|
||||||
|
USING_SPACES = True
|
||||||
|
spaces_zerogpu_decorator = spaces.GPU
|
||||||
|
except ImportError:
|
||||||
|
USING_SPACES = False
|
||||||
|
def spaces_zerogpu_decorator(func):
|
||||||
|
return func
|
||||||
|
|
||||||
is_windows = sys.platform.startswith("win")
|
is_windows = sys.platform.startswith("win")
|
||||||
|
|
||||||
|
@ -79,7 +90,6 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
|
||||||
def get_dtypes():
|
def get_dtypes():
|
||||||
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
|
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
|
||||||
|
|
||||||
from .models.arch import AVAILABLE_ATTENTIONS
|
|
||||||
def get_attentions():
|
def get_attentions():
|
||||||
return AVAILABLE_ATTENTIONS + ["auto"]
|
return AVAILABLE_ATTENTIONS + ["auto"]
|
||||||
|
|
||||||
|
@ -158,6 +168,7 @@ def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto",
|
||||||
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
|
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
|
@spaces_zerogpu_decorator
|
||||||
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
|
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
|
||||||
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
if not cfg.models:
|
if not cfg.models:
|
||||||
|
@ -501,37 +512,39 @@ with ui:
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with gr.Tab("Dataset"):
|
if not USING_SPACES:
|
||||||
with gr.Row():
|
with gr.Tab("Dataset"):
|
||||||
with gr.Column(scale=7):
|
with gr.Row():
|
||||||
layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata")
|
with gr.Column(scale=7):
|
||||||
with gr.Column(scale=1):
|
layout["dataset"]["outputs"]["transcription"] = gr.Textbox(lines=5, label="Sample Metadata")
|
||||||
layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers")
|
with gr.Column(scale=1):
|
||||||
layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output")
|
layout["dataset"]["inputs"]["speaker"] = gr.Dropdown(choices=get_speakers(), label="Speakers")
|
||||||
layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample")
|
layout["dataset"]["outputs"]["audio"] = gr.Audio(label="Output")
|
||||||
|
layout["dataset"]["buttons"]["sample"] = gr.Button(value="Sample")
|
||||||
|
|
||||||
layout["dataset"]["buttons"]["sample"].click(
|
layout["dataset"]["buttons"]["sample"].click(
|
||||||
fn=load_sample,
|
fn=load_sample,
|
||||||
inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None],
|
inputs=[ x for x in layout["dataset"]["inputs"].values() if x is not None],
|
||||||
outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
|
outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Tab("Settings"):
|
if not USING_SPACES:
|
||||||
with gr.Row():
|
with gr.Tab("Settings"):
|
||||||
with gr.Column(scale=7):
|
with gr.Row():
|
||||||
with gr.Row():
|
with gr.Column(scale=7):
|
||||||
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
|
with gr.Row():
|
||||||
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
|
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
|
||||||
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
|
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
|
||||||
layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
|
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
|
||||||
with gr.Column(scale=1):
|
layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
|
||||||
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
|
with gr.Column(scale=1):
|
||||||
|
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
|
||||||
|
|
||||||
layout["settings"]["buttons"]["load"].click(
|
layout["settings"]["buttons"]["load"].click(
|
||||||
fn=load_model,
|
fn=load_model,
|
||||||
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
|
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
|
||||||
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
|
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.path.exists("README.md") and args.render_markdown:
|
if os.path.exists("README.md") and args.render_markdown:
|
||||||
md = open("README.md", "r", encoding="utf-8").read()
|
md = open("README.md", "r", encoding="utf-8").read()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user