Re-added support for loading multiple models, because I'm once again entertaining split AR/NAR models (and need a way to load both at once)

This commit is contained in:
mrq 2024-06-06 09:48:43 -05:00
parent b05a905b95
commit b2194b859a
11 changed files with 105 additions and 91 deletions

View File

@ -13,7 +13,6 @@ def main():
parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--max-ar-steps", type=int, default=6 * 75) parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
parser.add_argument("--max-nar-levels", type=int, default=7) parser.add_argument("--max-nar-levels", type=int, default=7)
@ -40,7 +39,7 @@ def main():
parser.add_argument("--dtype", type=str, default=None) parser.add_argument("--dtype", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp ) tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
tts.inference( tts.inference(
text=args.text, text=args.text,
references=args.references, references=args.references,

View File

@ -256,7 +256,7 @@ class Model:
if self.interleave: if self.interleave:
name.append("interleaved") name.append("interleaved")
else: else:
name.append(f'{cfg.model.prom_levels}') name.append(f'{self.prom_levels}')
return "-".join(name) return "-".join(name)
@ -627,8 +627,7 @@ class Config(_Config):
experimental: bool = False # So I can stop commenting out things when committing experimental: bool = False # So I can stop commenting out things when committing
dataset: Dataset = field(default_factory=lambda: Dataset) dataset: Dataset = field(default_factory=lambda: Dataset)
model: Model = field(default_factory=lambda: Model) models: dict | list | None = field(default_factory=lambda: [Model])
models: dict | list | None = None # deprecated
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters) hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation) evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer) trainer: Trainer = field(default_factory=lambda: Trainer)
@ -643,6 +642,14 @@ class Config(_Config):
audio_backend: str = "vocos" audio_backend: str = "vocos"
@property
def model(self):
    """Return the primary model config out of ``self.models``.

    The first entry whose ``training`` flag is truthy is returned; when no
    entry is flagged for training, the first entry is returned instead.

    Raises:
        IndexError: if ``self.models`` is empty, since there is nothing to select.
    """
    if not self.models:
        # same exception type the bare `self.models[0]` would raise,
        # but with a message that names the actual problem
        raise IndexError("no models are defined in the config (`models` is empty)")
    # first model marked for training wins; otherwise fall back to the first entry
    return next((model for model in self.models if model.training), self.models[0])
@property @property
def distributed(self): def distributed(self):
return world_size() > 1 return world_size() > 1
@ -681,8 +688,8 @@ class Config(_Config):
if isinstance(self.dataset, type): if isinstance(self.dataset, type):
self.dataset = dict() self.dataset = dict()
if isinstance(self.model, type): if isinstance(self.models, type):
self.model = dict() self.models = dict()
if isinstance(self.hyperparameters, type): if isinstance(self.hyperparameters, type):
self.hyperparameters = dict() self.hyperparameters = dict()
@ -704,10 +711,14 @@ class Config(_Config):
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ] self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ] self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
"""
if self.models is not None: if self.models is not None:
self.model = Model(**next(iter(self.models))) self.model = Model(**next(iter(self.models)))
else: else:
self.model = Model(**self.model) self.model = Model(**self.model)
"""
self.models = [ Model(**model) for model in self.models ]
self.hyperparameters = Hyperparameters(**self.hyperparameters) self.hyperparameters = Hyperparameters(**self.hyperparameters)

View File

@ -26,14 +26,14 @@ from functools import cache
@cache @cache
def load_engines(training=True): def load_engines(training=True):
models = get_models(cfg.model.get(), training=training) models = get_models(cfg.models, training=training)
engines = dict() engines = dict()
for name, model in models.items(): for name, model in models.items():
optimizer = None optimizer = None
lr_scheduler = None lr_scheduler = None
inferencing = cfg.mode == "inferencing" or not model.hyper_config.training inferencing = cfg.mode == "inferencing" or not model.config.training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend backend = cfg.inference.backend if inferencing else cfg.trainer.backend
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp amp = cfg.inference.amp if inferencing else cfg.trainer.amp
@ -43,7 +43,7 @@ def load_engines(training=True):
engine_class = _Engine if backend == "local" or inferencing else Engine engine_class = _Engine if backend == "local" or inferencing else Engine
if inferencing: if inferencing:
model.hyper_config.training = False model.config.training = False
if cfg.optimizations.replace and cfg.optimizations.linear: if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model ) model.model = ml.replace_linear( model.model )
@ -83,7 +83,7 @@ def load_engines(training=True):
params.update(cfg.hyperparameters.optimizer_params) params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class( optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ], [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
**params, **params,
) )
@ -96,7 +96,7 @@ def load_engines(training=True):
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}') raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
optimizer = scheduler_class( optimizer = scheduler_class(
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ], [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
lr = params['lr'], lr = params['lr'],
warmup_steps = cfg.hyperparameters.warmup_steps warmup_steps = cfg.hyperparameters.warmup_steps
) )
@ -143,12 +143,16 @@ def load_engines(training=True):
del state[k] del state[k]
# resize text embedding # resize text embedding
if cfg.model.text_tokens != state["text_emb.weight"].shape[0]: if model.config.text_tokens != state["text_emb.weight"].shape[0]:
state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens] state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
# resize RVQ level embedding
if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
model.load_state_dict(state, strict=cfg.trainer.strict_loading) model.load_state_dict(state, strict=cfg.trainer.strict_loading)
hyper_config = model.hyper_config hyper_config = model.config
# wrap if DDP is requested # wrap if DDP is requested
if ddp: if ddp:

View File

@ -19,7 +19,7 @@ if deepspeed_available:
import deepspeed import deepspeed
class TTS(): class TTS():
def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ): def __init__( self, config=None, device=None, amp=None, dtype=None ):
self.loading = True self.loading = True
self.input_sample_rate = 24000 self.input_sample_rate = 24000
@ -53,32 +53,12 @@ class TTS():
self.symmap = None self.symmap = None
if model_ckpt: self.engines = load_engines(training=False)
state = torch.load(model_ckpt) for name, engine in self.engines.items():
self.model = get_models(cfg.model.get(), training=False)[0] if self.dtype != torch.int8:
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
if "userdata" in state and 'symmap' in state['userdata']:
self.symmap = state['userdata']['symmap']
elif "symmap" in state:
self.symmap = state['symmap']
if "module" in state: self.engines.eval()
state = state['module']
self.model.load_state_dict(state)
if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
else:
engines = load_engines(training=False)
for name, engine in engines.items():
self.model = engine.module
break
if self.dtype != torch.int8:
self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.model.eval()
if self.symmap is None: if self.symmap is None:
self.symmap = get_phone_symmap() self.symmap = get_phone_symmap()
@ -159,6 +139,15 @@ class TTS():
wavs = [] wavs = []
sr = None sr = None
model_ar = None
model_nar = None
for name, engine in self.engines.items():
if "ar" in engine.hyper_config.capabilities:
model_ar = engine.module
if "nar" in engine.hyper_config.capabilities:
model_nar = engine.module
for line in lines: for line in lines:
if out_path is None: if out_path is None:
out_path = f"./data/{cfg.start_time}.wav" out_path = f"./data/{cfg.start_time}.wav"
@ -172,7 +161,7 @@ class TTS():
lang = to_device(lang, self.device).to(torch.uint8) lang = to_device(lang, self.device).to(torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.model( resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context, text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp, sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp, sampling_min_temperature=min_ar_temp,
@ -183,8 +172,7 @@ class TTS():
sampling_mirostat_tau=mirostat_tau, sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta, sampling_mirostat_eta=mirostat_eta,
) )
resps_list = [r.unsqueeze(-1) for r in resps_list] resps_list = model_nar(
resps_list = self.model(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels, max_levels=max_nar_levels,
sampling_temperature=nar_temp, sampling_temperature=nar_temp,

View File

@ -1,37 +1,36 @@
def get_model(cfg, training=True): def get_model(config, training=True):
name = cfg.name name = config.name
if not cfg.experimental: if not config.experimental:
from .ar_nar import AR_NAR from .ar_nar import AR_NAR
model = AR_NAR( model = AR_NAR(
n_text_tokens=cfg.text_tokens, n_text_tokens=config.text_tokens,
n_audio_tokens=cfg.audio_tokens, n_audio_tokens=config.audio_tokens,
d_model=cfg.dim, d_model=config.dim,
n_heads=cfg.heads, n_heads=config.heads,
n_layers=cfg.layers, n_layers=config.layers,
n_experts=cfg.experts, n_experts=config.experts,
p_dropout=cfg.dropout, p_dropout=config.dropout,
l_padding = cfg.input_alignment, l_padding = config.input_alignment,
training = training, training = training,
config = cfg, config = config,
) )
model._cfg = cfg
else: else:
from .experimental import Model as Experimental from .experimental import Model as Experimental
model = Experimental( model = Experimental(
n_text_tokens=cfg.text_tokens, n_text_tokens=config.text_tokens,
n_audio_tokens=cfg.audio_tokens, n_audio_tokens=config.audio_tokens,
d_model=cfg.dim,
n_layers=cfg.layers,
n_heads=cfg.heads,
p_dropout=cfg.dropout,
config = cfg, d_model=config.dim,
n_layers=config.layers,
n_heads=config.heads,
p_dropout=config.dropout,
config = config,
) )
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

View File

@ -16,7 +16,7 @@ class AR_NAR(Base):
@property @property
def causal(self): def causal(self):
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
return "ar" in self.capabilities return "ar" in self.config.capabilities
return True return True
@property @property
@ -31,6 +31,8 @@ class AR_NAR(Base):
@property @property
def n_prom_levels(self) -> int: def n_prom_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.prom_levels
return cfg.model.prom_levels return cfg.model.prom_levels
@property @property
@ -41,18 +43,26 @@ class AR_NAR(Base):
@property @property
def n_max_levels(self) -> int: def n_max_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.max_levels
return cfg.model.max_levels return cfg.model.max_levels
@property @property
def n_tasks(self) -> int: def n_tasks(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.tasks
return cfg.model.tasks return cfg.model.tasks
@property @property
def n_langs(self) -> int: def n_langs(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.langs
return cfg.model.langs return cfg.model.langs
@property @property
def n_tones(self) -> int: def n_tones(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.tones
return cfg.model.tones return cfg.model.tones
@property @property

View File

@ -1,5 +1,7 @@
# https://github.com/kyegomez/BitNet # https://github.com/kyegomez/BitNet
from torch import Tensor, nn from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
# re-enable logging because zetascale fucking sucks # re-enable logging because zetascale fucking sucks

View File

@ -1,4 +1,6 @@
# https://github.com/state-spaces/mamba # https://github.com/state-spaces/mamba
from torch.utils.checkpoint import checkpoint
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs): def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):

View File

@ -1,12 +1,13 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
import torch import torch
import torch.nn.functional as F
from transformers import MixtralModel, MixtralConfig from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
# This is required because batch sizes > 1 throws errors # This is required because batch sizes > 1 throws errors
def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """ """ """
batch_size, sequence_length, hidden_dim = hidden_states.shape batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_dim) # was view() hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
@ -41,5 +42,4 @@ def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> to
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits return final_hidden_states, router_logits
Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward

View File

@ -207,9 +207,9 @@ class Base(nn.Module):
return -100 return -100
def loss_factor(self, k): def loss_factor(self, k):
if self.hyper_config is None: if self.config is None:
return 1.0 return 1.0
return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0 return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
def __init__( def __init__(
self, self,
@ -231,8 +231,8 @@ class Base(nn.Module):
): ):
super().__init__() super().__init__()
self.training = training self.training = training
self.hyper_config = config self.config = config
self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
self.n_text_tokens = n_text_tokens self.n_text_tokens = n_text_tokens
self.n_audio_tokens = n_audio_tokens self.n_audio_tokens = n_audio_tokens
@ -246,7 +246,7 @@ class Base(nn.Module):
# +1 to include the stop token # +1 to include the stop token
n_prom_tokens = n_audio_tokens n_prom_tokens = n_audio_tokens
n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
self.text_emb = Embedding(n_text_tokens, d_model) self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None self.langs_emb = None
@ -263,13 +263,13 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding( self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model, [n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None, levels=self.n_prom_levels if self.version > 3 else None,
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True, sums=self.config.audio_embedding_sums if self.config is not None else True,
) )
# [1024 + STOP] + [1024] * 8 # [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding( self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None, levels=self.n_resp_levels if self.version > 3 else None,
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True sums=self.config.audio_embedding_sums if self.config is not None else True
) )
# useless since I actually removed using these with the input processing overhaul... # useless since I actually removed using these with the input processing overhaul...
@ -290,20 +290,20 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model)) self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way # ick, there has to be a better way
hf_attention = self.hyper_config.attention if self.hyper_config is not None else None hf_attention = self.config.attention if self.config is not None else None
if self.hyper_config.attention == "auto": if self.config.attention == "auto":
if "flash" in AVAILABLE_ATTENTIONS: if "flash" in AVAILABLE_ATTENTIONS:
self.hyper_config.attention = "flash" self.config.attention = "flash"
elif "xformers" in AVAILABLE_ATTENTIONS: elif "xformers" in AVAILABLE_ATTENTIONS:
self.hyper_config.attention = "xformers" self.config.attention = "xformers"
else: else:
self.hyper_config.attention = "mem_efficient" self.config.attention = "mem_efficient"
if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]: if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
hf_attention = None hf_attention = None
if self.hyper_config.attention not in AVAILABLE_ATTENTIONS: if self.config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}") raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self.arch_type == "transformer": if self.arch_type == "transformer":
@ -326,7 +326,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers, num_hidden_layers=n_layers,
num_attention_heads=n_heads, num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0, attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads, num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
hidden_act="gelu", hidden_act="gelu",
is_encoder_decoder=False, is_encoder_decoder=False,
is_decoder=True, is_decoder=True,
@ -342,7 +342,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers, num_hidden_layers=n_layers,
num_attention_heads=n_heads, num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0, attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads, num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
sliding_window=75 * 12, # 12 second context window sliding_window=75 * 12, # 12 second context window
output_router_logits=training, output_router_logits=training,
hidden_act="gelu", hidden_act="gelu",
@ -492,8 +492,8 @@ class Base(nn.Module):
else: else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}') raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]: if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.hyper_config.attention ) self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
self.classifier = nn.Linear(d_model, n_resp_tokens) self.classifier = nn.Linear(d_model, n_resp_tokens)
@ -691,7 +691,7 @@ class Base(nn.Module):
quant_levels: Tensor | None = None, quant_levels: Tensor | None = None,
): ):
# old, "naive" way, no loss factoring # old, "naive" way, no loss factoring
if not self.hyper_config.loss_factors: if not self.config.loss_factors:
target_list = [] target_list = []
for batch_index, batch in enumerate(inputs): for batch_index, batch in enumerate(inputs):
target = [] target = []

View File

@ -54,13 +54,12 @@ def init_tts(restart=False):
parser = argparse.ArgumentParser(allow_abbrev=False) parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--amp", action="store_true") parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default="auto") parser.add_argument("--dtype", type=str, default="auto")
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return tts return tts
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())