web UI support for HF ZeroGPU
This commit is contained in:
parent
e58a9469a3
commit
e094bd23ec
|
@ -6,6 +6,7 @@ import argparse
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import functools
|
import functools
|
||||||
|
import spaces
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -21,7 +22,17 @@ from .utils import get_devices, setup_logging, timer
|
||||||
from .utils.io import json_read, json_stringify
|
from .utils.io import json_read, json_stringify
|
||||||
from .emb.qnt import decode_to_wave
|
from .emb.qnt import decode_to_wave
|
||||||
from .data import get_lang_symmap, get_random_prompt
|
from .data import get_lang_symmap, get_random_prompt
|
||||||
|
from .models.arch import AVAILABLE_ATTENTIONS
|
||||||
|
|
||||||
|
try:
|
||||||
|
import spaces
|
||||||
|
|
||||||
|
USING_SPACES = True
|
||||||
|
spaces_zerogpu_decorator = spaces.GPU
|
||||||
|
except ImportError:
|
||||||
|
USING_SPACES = False
|
||||||
|
def spaces_zerogpu_decorator(func):
|
||||||
|
return func
|
||||||
|
|
||||||
is_windows = sys.platform.startswith("win")
|
is_windows = sys.platform.startswith("win")
|
||||||
|
|
||||||
|
@ -79,7 +90,6 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
|
||||||
def get_dtypes():
|
def get_dtypes():
|
||||||
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
|
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
|
||||||
|
|
||||||
from .models.arch import AVAILABLE_ATTENTIONS
|
|
||||||
def get_attentions():
|
def get_attentions():
|
||||||
return AVAILABLE_ATTENTIONS + ["auto"]
|
return AVAILABLE_ATTENTIONS + ["auto"]
|
||||||
|
|
||||||
|
@ -158,6 +168,7 @@ def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto",
|
||||||
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
|
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
|
@spaces_zerogpu_decorator
|
||||||
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
|
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
|
||||||
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
if not cfg.models:
|
if not cfg.models:
|
||||||
|
@ -501,6 +512,7 @@ with ui:
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not USING_SPACES:
|
||||||
with gr.Tab("Dataset"):
|
with gr.Tab("Dataset"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=7):
|
with gr.Column(scale=7):
|
||||||
|
@ -516,6 +528,7 @@ with ui:
|
||||||
outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
|
outputs=[ x for x in layout["dataset"]["outputs"].values() if x is not None],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not USING_SPACES:
|
||||||
with gr.Tab("Settings"):
|
with gr.Tab("Settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=7):
|
with gr.Column(scale=7):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user