web UI support for HF ZeroGPU

This commit is contained in:
mrq 2024-11-05 21:10:26 -06:00
parent e58a9469a3
commit 57dc3e89d5

View File

@ -6,6 +6,7 @@ import argparse
import random
import tempfile
import functools
import spaces
import torch
import numpy as np
@ -22,6 +23,14 @@ from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
from .data import get_lang_symmap, get_random_prompt
try:
import spaces
USING_SPACES = True
spaces_zerogpu_decorator = spaces.GPU
except ImportError:
USING_SPACES = False
spaces_zerogpu_decorator = lambda func: func
is_windows = sys.platform.startswith("win")
@ -158,6 +167,7 @@ def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto",
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
return tts
@spaces_zerogpu_decorator
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if not cfg.models: