web UI support for HF ZeroGPU

mrq 2024-11-05 21:22:29 -06:00
parent e58a9469a3
commit e094bd23ec


@@ -6,6 +6,7 @@ import argparse
 import random
 import tempfile
 import functools
+import spaces
 import torch
 import numpy as np
@@ -21,7 +22,17 @@ from .utils import get_devices, setup_logging, timer
 from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
 from .data import get_lang_symmap, get_random_prompt
+from .models.arch import AVAILABLE_ATTENTIONS
+try:
+	import spaces
+	USING_SPACES = True
+	spaces_zerogpu_decorator = spaces.GPU
+except ImportError:
+	USING_SPACES = False
+	def spaces_zerogpu_decorator(func):
+		return func

 is_windows = sys.platform.startswith("win")
@@ -79,7 +90,6 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
 def get_dtypes():
 	return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
-from .models.arch import AVAILABLE_ATTENTIONS
 def get_attentions():
 	return AVAILABLE_ATTENTIONS + ["auto"]
@@ -158,6 +168,7 @@ def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto",
 	tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
 	return tts

+@spaces_zerogpu_decorator
 @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
 def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	if not cfg.models:
@@ -501,6 +512,7 @@ with ui:
 			)
 	"""

+	if not USING_SPACES:
 	with gr.Tab("Dataset"):
 		with gr.Row():
 			with gr.Column(scale=7):
@@ -516,6 +528,7 @@ with ui:
 				outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
 			)

+	if not USING_SPACES:
 	with gr.Tab("Settings"):
 		with gr.Row():
 			with gr.Column(scale=7):
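For context, a minimal standalone sketch of the pattern this commit introduces, reusing the `USING_SPACES` / `spaces_zerogpu_decorator` names from the diff above: the HF `spaces` package (and its `spaces.GPU` decorator) only exists on a ZeroGPU Space, so it falls back to a no-op decorator elsewhere, and Space-inappropriate tabs are gated behind `if not USING_SPACES:`. The `do_inference` function and the tiny Blocks UI below are illustrative placeholders, not the project's actual webui.py.

import gradio as gr

# ZeroGPU shim: `spaces` is only installed on an HF Space, so degrade to a
# no-op decorator when it is missing (local use)
try:
	import spaces
	USING_SPACES = True
	spaces_zerogpu_decorator = spaces.GPU
except ImportError:
	USING_SPACES = False
	def spaces_zerogpu_decorator( func ):
		return func

@spaces_zerogpu_decorator
def do_inference( text ):
	# on a ZeroGPU Space a GPU is attached only for the duration of this call;
	# locally the decorator returns the function unchanged
	return f"synthesized: {text}"  # placeholder for the real TTS call

with gr.Blocks() as ui:
	text = gr.Textbox(label="Input Text")
	output = gr.Textbox(label="Output")
	gr.Button("Inference").click(do_inference, inputs=[text], outputs=[output])

	# mirror the diff's guards: hide tabs that make no sense on a hosted Space
	if not USING_SPACES:
		with gr.Tab("Settings"):
			gr.Markdown("local-only settings would go here")

if __name__ == "__main__":
	ui.launch()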