re-added loading multiple models because I'm now entertaining having split AR/NAR models again (and need a way to load both at once)
This commit is contained in:
parent
b05a905b95
commit
b2194b859a
|
@ -13,7 +13,6 @@ def main():
|
|||
parser.add_argument("--out-path", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--model-ckpt", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
|
||||
parser.add_argument("--max-nar-levels", type=int, default=7)
|
||||
|
@ -40,7 +39,7 @@ def main():
|
|||
parser.add_argument("--dtype", type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||
tts.inference(
|
||||
text=args.text,
|
||||
references=args.references,
|
||||
|
|
|
@ -256,7 +256,7 @@ class Model:
|
|||
if self.interleave:
|
||||
name.append("interleaved")
|
||||
else:
|
||||
name.append(f'{cfg.model.prom_levels}')
|
||||
name.append(f'{self.prom_levels}')
|
||||
|
||||
|
||||
return "-".join(name)
|
||||
|
@ -627,8 +627,7 @@ class Config(_Config):
|
|||
experimental: bool = False # So I can stop commenting out things when committing
|
||||
|
||||
dataset: Dataset = field(default_factory=lambda: Dataset)
|
||||
model: Model = field(default_factory=lambda: Model)
|
||||
models: dict | list | None = None # deprecated
|
||||
models: dict | list | None = field(default_factory=lambda: [Model])
|
||||
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
|
||||
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
|
||||
trainer: Trainer = field(default_factory=lambda: Trainer)
|
||||
|
@ -643,6 +642,14 @@ class Config(_Config):
|
|||
|
||||
audio_backend: str = "vocos"
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
for i, model in enumerate(self.models):
|
||||
if model.training:
|
||||
return model
|
||||
|
||||
return self.models[0]
|
||||
|
||||
@property
|
||||
def distributed(self):
|
||||
return world_size() > 1
|
||||
|
@ -681,8 +688,8 @@ class Config(_Config):
|
|||
if isinstance(self.dataset, type):
|
||||
self.dataset = dict()
|
||||
|
||||
if isinstance(self.model, type):
|
||||
self.model = dict()
|
||||
if isinstance(self.models, type):
|
||||
self.models = dict()
|
||||
|
||||
if isinstance(self.hyperparameters, type):
|
||||
self.hyperparameters = dict()
|
||||
|
@ -704,10 +711,14 @@ class Config(_Config):
|
|||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
||||
|
||||
"""
|
||||
if self.models is not None:
|
||||
self.model = Model(**next(iter(self.models)))
|
||||
else:
|
||||
self.model = Model(**self.model)
|
||||
"""
|
||||
|
||||
self.models = [ Model(**model) for model in self.models ]
|
||||
|
||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||
|
||||
|
|
|
@ -26,14 +26,14 @@ from functools import cache
|
|||
|
||||
@cache
|
||||
def load_engines(training=True):
|
||||
models = get_models(cfg.model.get(), training=training)
|
||||
models = get_models(cfg.models, training=training)
|
||||
engines = dict()
|
||||
|
||||
for name, model in models.items():
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
inferencing = cfg.mode == "inferencing" or not model.hyper_config.training
|
||||
inferencing = cfg.mode == "inferencing" or not model.config.training
|
||||
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
|
||||
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
||||
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
||||
|
@ -43,7 +43,7 @@ def load_engines(training=True):
|
|||
engine_class = _Engine if backend == "local" or inferencing else Engine
|
||||
|
||||
if inferencing:
|
||||
model.hyper_config.training = False
|
||||
model.config.training = False
|
||||
|
||||
if cfg.optimizations.replace and cfg.optimizations.linear:
|
||||
model.model = ml.replace_linear( model.model )
|
||||
|
@ -83,7 +83,7 @@ def load_engines(training=True):
|
|||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
|
||||
optimizer = optimizer_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
|
||||
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
|
||||
|
@ -96,7 +96,7 @@ def load_engines(training=True):
|
|||
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
optimizer = scheduler_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
|
||||
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||
lr = params['lr'],
|
||||
warmup_steps = cfg.hyperparameters.warmup_steps
|
||||
)
|
||||
|
@ -143,12 +143,16 @@ def load_engines(training=True):
|
|||
del state[k]
|
||||
|
||||
# resize text embedding
|
||||
if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
|
||||
state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
|
||||
if model.config.text_tokens != state["text_emb.weight"].shape[0]:
|
||||
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
|
||||
|
||||
# resize text embedding
|
||||
if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
|
||||
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
hyper_config = model.hyper_config
|
||||
hyper_config = model.config
|
||||
|
||||
# wrap if DDP is requested
|
||||
if ddp:
|
||||
|
|
|
@ -19,7 +19,7 @@ if deepspeed_available:
|
|||
import deepspeed
|
||||
|
||||
class TTS():
|
||||
def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ):
|
||||
def __init__( self, config=None, device=None, amp=None, dtype=None ):
|
||||
self.loading = True
|
||||
|
||||
self.input_sample_rate = 24000
|
||||
|
@ -53,32 +53,12 @@ class TTS():
|
|||
|
||||
self.symmap = None
|
||||
|
||||
if model_ckpt:
|
||||
state = torch.load(model_ckpt)
|
||||
self.model = get_models(cfg.model.get(), training=False)[0]
|
||||
|
||||
if "userdata" in state and 'symmap' in state['userdata']:
|
||||
self.symmap = state['userdata']['symmap']
|
||||
elif "symmap" in state:
|
||||
self.symmap = state['symmap']
|
||||
self.engines = load_engines(training=False)
|
||||
for name, engine in self.engines.items():
|
||||
if self.dtype != torch.int8:
|
||||
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
|
||||
if "module" in state:
|
||||
state = state['module']
|
||||
|
||||
self.model.load_state_dict(state)
|
||||
|
||||
if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
|
||||
self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
|
||||
else:
|
||||
engines = load_engines(training=False)
|
||||
for name, engine in engines.items():
|
||||
self.model = engine.module
|
||||
break
|
||||
|
||||
if self.dtype != torch.int8:
|
||||
self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
|
||||
self.model.eval()
|
||||
self.engines.eval()
|
||||
|
||||
if self.symmap is None:
|
||||
self.symmap = get_phone_symmap()
|
||||
|
@ -159,6 +139,15 @@ class TTS():
|
|||
wavs = []
|
||||
sr = None
|
||||
|
||||
model_ar = None
|
||||
model_nar = None
|
||||
|
||||
for name, engine in self.engines.items():
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
model_ar = engine.module
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
model_nar = engine.module
|
||||
|
||||
for line in lines:
|
||||
if out_path is None:
|
||||
out_path = f"./data/{cfg.start_time}.wav"
|
||||
|
@ -172,7 +161,7 @@ class TTS():
|
|||
lang = to_device(lang, self.device).to(torch.uint8)
|
||||
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
resps_list = self.model(
|
||||
resps_list = model_ar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
||||
sampling_temperature=ar_temp,
|
||||
sampling_min_temperature=min_ar_temp,
|
||||
|
@ -183,8 +172,7 @@ class TTS():
|
|||
sampling_mirostat_tau=mirostat_tau,
|
||||
sampling_mirostat_eta=mirostat_eta,
|
||||
)
|
||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||
resps_list = self.model(
|
||||
resps_list = model_nar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
|
||||
max_levels=max_nar_levels,
|
||||
sampling_temperature=nar_temp,
|
||||
|
|
|
@ -1,37 +1,36 @@
|
|||
|
||||
def get_model(cfg, training=True):
|
||||
name = cfg.name
|
||||
def get_model(config, training=True):
|
||||
name = config.name
|
||||
|
||||
if not cfg.experimental:
|
||||
if not config.experimental:
|
||||
from .ar_nar import AR_NAR
|
||||
model = AR_NAR(
|
||||
n_text_tokens=cfg.text_tokens,
|
||||
n_audio_tokens=cfg.audio_tokens,
|
||||
d_model=cfg.dim,
|
||||
n_heads=cfg.heads,
|
||||
n_layers=cfg.layers,
|
||||
n_experts=cfg.experts,
|
||||
n_text_tokens=config.text_tokens,
|
||||
n_audio_tokens=config.audio_tokens,
|
||||
d_model=config.dim,
|
||||
n_heads=config.heads,
|
||||
n_layers=config.layers,
|
||||
n_experts=config.experts,
|
||||
|
||||
p_dropout=cfg.dropout,
|
||||
p_dropout=config.dropout,
|
||||
|
||||
l_padding = cfg.input_alignment,
|
||||
l_padding = config.input_alignment,
|
||||
|
||||
training = training,
|
||||
config = cfg,
|
||||
config = config,
|
||||
)
|
||||
model._cfg = cfg
|
||||
else:
|
||||
from .experimental import Model as Experimental
|
||||
model = Experimental(
|
||||
n_text_tokens=cfg.text_tokens,
|
||||
n_audio_tokens=cfg.audio_tokens,
|
||||
|
||||
d_model=cfg.dim,
|
||||
n_layers=cfg.layers,
|
||||
n_heads=cfg.heads,
|
||||
p_dropout=cfg.dropout,
|
||||
n_text_tokens=config.text_tokens,
|
||||
n_audio_tokens=config.audio_tokens,
|
||||
|
||||
config = cfg,
|
||||
d_model=config.dim,
|
||||
n_layers=config.layers,
|
||||
n_heads=config.heads,
|
||||
p_dropout=config.dropout,
|
||||
|
||||
config = config,
|
||||
)
|
||||
|
||||
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
|
||||
|
|
|
@ -16,7 +16,7 @@ class AR_NAR(Base):
|
|||
@property
|
||||
def causal(self):
|
||||
if hasattr(self, "config") and self.config:
|
||||
return "ar" in self.capabilities
|
||||
return "ar" in self.config.capabilities
|
||||
return True
|
||||
|
||||
@property
|
||||
|
@ -31,6 +31,8 @@ class AR_NAR(Base):
|
|||
|
||||
@property
|
||||
def n_prom_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.prom_levels
|
||||
return cfg.model.prom_levels
|
||||
|
||||
@property
|
||||
|
@ -41,18 +43,26 @@ class AR_NAR(Base):
|
|||
|
||||
@property
|
||||
def n_max_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.max_levels
|
||||
return cfg.model.max_levels
|
||||
|
||||
@property
|
||||
def n_tasks(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.tasks
|
||||
return cfg.model.tasks
|
||||
|
||||
@property
|
||||
def n_langs(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.langs
|
||||
return cfg.model.langs
|
||||
|
||||
@property
|
||||
def n_tones(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.tones
|
||||
return cfg.model.tones
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# https://github.com/kyegomez/BitNet
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
|
||||
|
||||
# re-enable logging because zetascale fucking sucks
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# https://github.com/state-spaces/mamba
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
|
||||
|
||||
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import MixtralModel, MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
|
||||
|
||||
# This is required because batch sizes > 1 throws errors
|
||||
def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
""" """
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
|
||||
|
@ -41,5 +42,4 @@ def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> to
|
|||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
|
||||
MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
|
||||
MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
|
|
@ -207,9 +207,9 @@ class Base(nn.Module):
|
|||
return -100
|
||||
|
||||
def loss_factor(self, k):
|
||||
if self.hyper_config is None:
|
||||
if self.config is None:
|
||||
return 1.0
|
||||
return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0
|
||||
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -231,8 +231,8 @@ class Base(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
self.training = training
|
||||
self.hyper_config = config
|
||||
self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
|
||||
self.config = config
|
||||
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
||||
|
||||
self.n_text_tokens = n_text_tokens
|
||||
self.n_audio_tokens = n_audio_tokens
|
||||
|
@ -246,7 +246,7 @@ class Base(nn.Module):
|
|||
|
||||
# +1 to include the stop token
|
||||
n_prom_tokens = n_audio_tokens
|
||||
n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
|
||||
n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
|
@ -263,13 +263,13 @@ class Base(nn.Module):
|
|||
self.proms_emb = AudioEmbedding(
|
||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||
levels=self.n_prom_levels if self.version > 3 else None,
|
||||
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
|
||||
sums=self.config.audio_embedding_sums if self.config is not None else True,
|
||||
)
|
||||
# [1024 + STOP] + [1024] * 8
|
||||
self.resps_emb = AudioEmbedding(
|
||||
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
||||
levels=self.n_resp_levels if self.version > 3 else None,
|
||||
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
|
||||
sums=self.config.audio_embedding_sums if self.config is not None else True
|
||||
)
|
||||
|
||||
# useless since I actually removed using these with the input processing overhaul...
|
||||
|
@ -290,20 +290,20 @@ class Base(nn.Module):
|
|||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
# ick, there has to be a better way
|
||||
hf_attention = self.hyper_config.attention if self.hyper_config is not None else None
|
||||
hf_attention = self.config.attention if self.config is not None else None
|
||||
|
||||
if self.hyper_config.attention == "auto":
|
||||
if self.config.attention == "auto":
|
||||
if "flash" in AVAILABLE_ATTENTIONS:
|
||||
self.hyper_config.attention = "flash"
|
||||
self.config.attention = "flash"
|
||||
elif "xformers" in AVAILABLE_ATTENTIONS:
|
||||
self.hyper_config.attention = "xformers"
|
||||
self.config.attention = "xformers"
|
||||
else:
|
||||
self.hyper_config.attention = "mem_efficient"
|
||||
self.config.attention = "mem_efficient"
|
||||
|
||||
if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
|
||||
if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
|
||||
hf_attention = None
|
||||
if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
|
||||
raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
|
||||
if self.config.attention not in AVAILABLE_ATTENTIONS:
|
||||
raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
|
||||
|
||||
|
||||
if self.arch_type == "transformer":
|
||||
|
@ -326,7 +326,7 @@ class Base(nn.Module):
|
|||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
|
||||
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
|
@ -342,7 +342,7 @@ class Base(nn.Module):
|
|||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
|
||||
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
|
||||
sliding_window=75 * 12, # 12 second context window
|
||||
output_router_logits=training,
|
||||
hidden_act="gelu",
|
||||
|
@ -492,8 +492,8 @@ class Base(nn.Module):
|
|||
else:
|
||||
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
|
||||
|
||||
if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
|
||||
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.hyper_config.attention )
|
||||
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
|
||||
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
|
||||
|
||||
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
||||
|
||||
|
@ -691,7 +691,7 @@ class Base(nn.Module):
|
|||
quant_levels: Tensor | None = None,
|
||||
):
|
||||
# old, "naive" way, no loss factoring
|
||||
if not self.hyper_config.loss_factors:
|
||||
if not self.config.loss_factors:
|
||||
target_list = []
|
||||
for batch_index, batch in enumerate(inputs):
|
||||
target = []
|
||||
|
|
|
@ -54,13 +54,12 @@ def init_tts(restart=False):
|
|||
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||
parser.add_argument("--model-ckpt", type=Path, default=None)
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--dtype", type=str, default="auto")
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
|
||||
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
|
||||
return tts
|
||||
|
||||
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
||||
|
|
Loading…
Reference in New Issue
Block a user