Re-added loading of multiple models, because I'm once again considering split AR/NAR models (and need a way to load both at once)
This commit is contained in:
parent
b05a905b95
commit
b2194b859a
@ -13,7 +13,6 @@ def main():
|
|||||||
parser.add_argument("--out-path", type=Path, default=None)
|
parser.add_argument("--out-path", type=Path, default=None)
|
||||||
|
|
||||||
parser.add_argument("--yaml", type=Path, default=None)
|
parser.add_argument("--yaml", type=Path, default=None)
|
||||||
parser.add_argument("--model-ckpt", type=Path, default=None)
|
|
||||||
|
|
||||||
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
|
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
|
||||||
parser.add_argument("--max-nar-levels", type=int, default=7)
|
parser.add_argument("--max-nar-levels", type=int, default=7)
|
||||||
@ -40,7 +39,7 @@ def main():
|
|||||||
parser.add_argument("--dtype", type=str, default=None)
|
parser.add_argument("--dtype", type=str, default=None)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||||
tts.inference(
|
tts.inference(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
references=args.references,
|
references=args.references,
|
||||||
|
|||||||
@ -256,7 +256,7 @@ class Model:
|
|||||||
if self.interleave:
|
if self.interleave:
|
||||||
name.append("interleaved")
|
name.append("interleaved")
|
||||||
else:
|
else:
|
||||||
name.append(f'{cfg.model.prom_levels}')
|
name.append(f'{self.prom_levels}')
|
||||||
|
|
||||||
|
|
||||||
return "-".join(name)
|
return "-".join(name)
|
||||||
@ -627,8 +627,7 @@ class Config(_Config):
|
|||||||
experimental: bool = False # So I can stop commenting out things when committing
|
experimental: bool = False # So I can stop commenting out things when committing
|
||||||
|
|
||||||
dataset: Dataset = field(default_factory=lambda: Dataset)
|
dataset: Dataset = field(default_factory=lambda: Dataset)
|
||||||
model: Model = field(default_factory=lambda: Model)
|
models: dict | list | None = field(default_factory=lambda: [Model])
|
||||||
models: dict | list | None = None # deprecated
|
|
||||||
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
|
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
|
||||||
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
|
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
|
||||||
trainer: Trainer = field(default_factory=lambda: Trainer)
|
trainer: Trainer = field(default_factory=lambda: Trainer)
|
||||||
@ -643,6 +642,14 @@ class Config(_Config):
|
|||||||
|
|
||||||
audio_backend: str = "vocos"
|
audio_backend: str = "vocos"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
for i, model in enumerate(self.models):
|
||||||
|
if model.training:
|
||||||
|
return model
|
||||||
|
|
||||||
|
return self.models[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def distributed(self):
|
def distributed(self):
|
||||||
return world_size() > 1
|
return world_size() > 1
|
||||||
@ -681,8 +688,8 @@ class Config(_Config):
|
|||||||
if isinstance(self.dataset, type):
|
if isinstance(self.dataset, type):
|
||||||
self.dataset = dict()
|
self.dataset = dict()
|
||||||
|
|
||||||
if isinstance(self.model, type):
|
if isinstance(self.models, type):
|
||||||
self.model = dict()
|
self.models = dict()
|
||||||
|
|
||||||
if isinstance(self.hyperparameters, type):
|
if isinstance(self.hyperparameters, type):
|
||||||
self.hyperparameters = dict()
|
self.hyperparameters = dict()
|
||||||
@ -704,10 +711,14 @@ class Config(_Config):
|
|||||||
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
|
||||||
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
|
||||||
|
|
||||||
|
"""
|
||||||
if self.models is not None:
|
if self.models is not None:
|
||||||
self.model = Model(**next(iter(self.models)))
|
self.model = Model(**next(iter(self.models)))
|
||||||
else:
|
else:
|
||||||
self.model = Model(**self.model)
|
self.model = Model(**self.model)
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.models = [ Model(**model) for model in self.models ]
|
||||||
|
|
||||||
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||||
|
|
||||||
|
|||||||
@ -26,14 +26,14 @@ from functools import cache
|
|||||||
|
|
||||||
@cache
|
@cache
|
||||||
def load_engines(training=True):
|
def load_engines(training=True):
|
||||||
models = get_models(cfg.model.get(), training=training)
|
models = get_models(cfg.models, training=training)
|
||||||
engines = dict()
|
engines = dict()
|
||||||
|
|
||||||
for name, model in models.items():
|
for name, model in models.items():
|
||||||
optimizer = None
|
optimizer = None
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
|
|
||||||
inferencing = cfg.mode == "inferencing" or not model.hyper_config.training
|
inferencing = cfg.mode == "inferencing" or not model.config.training
|
||||||
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
|
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
|
||||||
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
||||||
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
||||||
@ -43,7 +43,7 @@ def load_engines(training=True):
|
|||||||
engine_class = _Engine if backend == "local" or inferencing else Engine
|
engine_class = _Engine if backend == "local" or inferencing else Engine
|
||||||
|
|
||||||
if inferencing:
|
if inferencing:
|
||||||
model.hyper_config.training = False
|
model.config.training = False
|
||||||
|
|
||||||
if cfg.optimizations.replace and cfg.optimizations.linear:
|
if cfg.optimizations.replace and cfg.optimizations.linear:
|
||||||
model.model = ml.replace_linear( model.model )
|
model.model = ml.replace_linear( model.model )
|
||||||
@ -83,7 +83,7 @@ def load_engines(training=True):
|
|||||||
params.update(cfg.hyperparameters.optimizer_params)
|
params.update(cfg.hyperparameters.optimizer_params)
|
||||||
|
|
||||||
optimizer = optimizer_class(
|
optimizer = optimizer_class(
|
||||||
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
|
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||||
**params,
|
**params,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ def load_engines(training=True):
|
|||||||
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
|
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
|
||||||
|
|
||||||
optimizer = scheduler_class(
|
optimizer = scheduler_class(
|
||||||
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
|
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||||
lr = params['lr'],
|
lr = params['lr'],
|
||||||
warmup_steps = cfg.hyperparameters.warmup_steps
|
warmup_steps = cfg.hyperparameters.warmup_steps
|
||||||
)
|
)
|
||||||
@ -143,12 +143,16 @@ def load_engines(training=True):
|
|||||||
del state[k]
|
del state[k]
|
||||||
|
|
||||||
# resize text embedding
|
# resize text embedding
|
||||||
if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
|
if model.config.text_tokens != state["text_emb.weight"].shape[0]:
|
||||||
state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
|
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
|
||||||
|
|
||||||
|
# resize RVQ level embedding
|
||||||
|
if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
|
||||||
|
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
|
||||||
|
|
||||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||||
|
|
||||||
hyper_config = model.hyper_config
|
hyper_config = model.config
|
||||||
|
|
||||||
# wrap if DDP is requested
|
# wrap if DDP is requested
|
||||||
if ddp:
|
if ddp:
|
||||||
|
|||||||
@ -19,7 +19,7 @@ if deepspeed_available:
|
|||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
class TTS():
|
class TTS():
|
||||||
def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ):
|
def __init__( self, config=None, device=None, amp=None, dtype=None ):
|
||||||
self.loading = True
|
self.loading = True
|
||||||
|
|
||||||
self.input_sample_rate = 24000
|
self.input_sample_rate = 24000
|
||||||
@ -53,32 +53,12 @@ class TTS():
|
|||||||
|
|
||||||
self.symmap = None
|
self.symmap = None
|
||||||
|
|
||||||
if model_ckpt:
|
self.engines = load_engines(training=False)
|
||||||
state = torch.load(model_ckpt)
|
for name, engine in self.engines.items():
|
||||||
self.model = get_models(cfg.model.get(), training=False)[0]
|
if self.dtype != torch.int8:
|
||||||
|
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||||
if "userdata" in state and 'symmap' in state['userdata']:
|
|
||||||
self.symmap = state['userdata']['symmap']
|
|
||||||
elif "symmap" in state:
|
|
||||||
self.symmap = state['symmap']
|
|
||||||
|
|
||||||
if "module" in state:
|
self.engines.eval()
|
||||||
state = state['module']
|
|
||||||
|
|
||||||
self.model.load_state_dict(state)
|
|
||||||
|
|
||||||
if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
|
|
||||||
self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
|
|
||||||
else:
|
|
||||||
engines = load_engines(training=False)
|
|
||||||
for name, engine in engines.items():
|
|
||||||
self.model = engine.module
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.dtype != torch.int8:
|
|
||||||
self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
|
||||||
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
if self.symmap is None:
|
if self.symmap is None:
|
||||||
self.symmap = get_phone_symmap()
|
self.symmap = get_phone_symmap()
|
||||||
@ -159,6 +139,15 @@ class TTS():
|
|||||||
wavs = []
|
wavs = []
|
||||||
sr = None
|
sr = None
|
||||||
|
|
||||||
|
model_ar = None
|
||||||
|
model_nar = None
|
||||||
|
|
||||||
|
for name, engine in self.engines.items():
|
||||||
|
if "ar" in engine.hyper_config.capabilities:
|
||||||
|
model_ar = engine.module
|
||||||
|
if "nar" in engine.hyper_config.capabilities:
|
||||||
|
model_nar = engine.module
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
out_path = f"./data/{cfg.start_time}.wav"
|
out_path = f"./data/{cfg.start_time}.wav"
|
||||||
@ -172,7 +161,7 @@ class TTS():
|
|||||||
lang = to_device(lang, self.device).to(torch.uint8)
|
lang = to_device(lang, self.device).to(torch.uint8)
|
||||||
|
|
||||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||||
resps_list = self.model(
|
resps_list = model_ar(
|
||||||
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
||||||
sampling_temperature=ar_temp,
|
sampling_temperature=ar_temp,
|
||||||
sampling_min_temperature=min_ar_temp,
|
sampling_min_temperature=min_ar_temp,
|
||||||
@ -183,8 +172,7 @@ class TTS():
|
|||||||
sampling_mirostat_tau=mirostat_tau,
|
sampling_mirostat_tau=mirostat_tau,
|
||||||
sampling_mirostat_eta=mirostat_eta,
|
sampling_mirostat_eta=mirostat_eta,
|
||||||
)
|
)
|
||||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
resps_list = model_nar(
|
||||||
resps_list = self.model(
|
|
||||||
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
|
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
|
||||||
max_levels=max_nar_levels,
|
max_levels=max_nar_levels,
|
||||||
sampling_temperature=nar_temp,
|
sampling_temperature=nar_temp,
|
||||||
|
|||||||
@ -1,37 +1,36 @@
|
|||||||
|
|
||||||
def get_model(cfg, training=True):
|
def get_model(config, training=True):
|
||||||
name = cfg.name
|
name = config.name
|
||||||
|
|
||||||
if not cfg.experimental:
|
if not config.experimental:
|
||||||
from .ar_nar import AR_NAR
|
from .ar_nar import AR_NAR
|
||||||
model = AR_NAR(
|
model = AR_NAR(
|
||||||
n_text_tokens=cfg.text_tokens,
|
n_text_tokens=config.text_tokens,
|
||||||
n_audio_tokens=cfg.audio_tokens,
|
n_audio_tokens=config.audio_tokens,
|
||||||
d_model=cfg.dim,
|
d_model=config.dim,
|
||||||
n_heads=cfg.heads,
|
n_heads=config.heads,
|
||||||
n_layers=cfg.layers,
|
n_layers=config.layers,
|
||||||
n_experts=cfg.experts,
|
n_experts=config.experts,
|
||||||
|
|
||||||
p_dropout=cfg.dropout,
|
p_dropout=config.dropout,
|
||||||
|
|
||||||
l_padding = cfg.input_alignment,
|
l_padding = config.input_alignment,
|
||||||
|
|
||||||
training = training,
|
training = training,
|
||||||
config = cfg,
|
config = config,
|
||||||
)
|
)
|
||||||
model._cfg = cfg
|
|
||||||
else:
|
else:
|
||||||
from .experimental import Model as Experimental
|
from .experimental import Model as Experimental
|
||||||
model = Experimental(
|
model = Experimental(
|
||||||
n_text_tokens=cfg.text_tokens,
|
n_text_tokens=config.text_tokens,
|
||||||
n_audio_tokens=cfg.audio_tokens,
|
n_audio_tokens=config.audio_tokens,
|
||||||
|
|
||||||
d_model=cfg.dim,
|
|
||||||
n_layers=cfg.layers,
|
|
||||||
n_heads=cfg.heads,
|
|
||||||
p_dropout=cfg.dropout,
|
|
||||||
|
|
||||||
config = cfg,
|
d_model=config.dim,
|
||||||
|
n_layers=config.layers,
|
||||||
|
n_heads=config.heads,
|
||||||
|
p_dropout=config.dropout,
|
||||||
|
|
||||||
|
config = config,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
|
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
|
||||||
|
|||||||
@ -16,7 +16,7 @@ class AR_NAR(Base):
|
|||||||
@property
|
@property
|
||||||
def causal(self):
|
def causal(self):
|
||||||
if hasattr(self, "config") and self.config:
|
if hasattr(self, "config") and self.config:
|
||||||
return "ar" in self.capabilities
|
return "ar" in self.config.capabilities
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -31,6 +31,8 @@ class AR_NAR(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def n_prom_levels(self) -> int:
|
def n_prom_levels(self) -> int:
|
||||||
|
if hasattr(self, "config") and self.config:
|
||||||
|
return self.config.prom_levels
|
||||||
return cfg.model.prom_levels
|
return cfg.model.prom_levels
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -41,18 +43,26 @@ class AR_NAR(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def n_max_levels(self) -> int:
|
def n_max_levels(self) -> int:
|
||||||
|
if hasattr(self, "config") and self.config:
|
||||||
|
return self.config.max_levels
|
||||||
return cfg.model.max_levels
|
return cfg.model.max_levels
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_tasks(self) -> int:
|
def n_tasks(self) -> int:
|
||||||
|
if hasattr(self, "config") and self.config:
|
||||||
|
return self.config.tasks
|
||||||
return cfg.model.tasks
|
return cfg.model.tasks
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_langs(self) -> int:
|
def n_langs(self) -> int:
|
||||||
|
if hasattr(self, "config") and self.config:
|
||||||
|
return self.config.langs
|
||||||
return cfg.model.langs
|
return cfg.model.langs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_tones(self) -> int:
|
def n_tones(self) -> int:
|
||||||
|
if hasattr(self, "config") and self.config:
|
||||||
|
return self.config.tones
|
||||||
return cfg.model.tones
|
return cfg.model.tones
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
# https://github.com/kyegomez/BitNet
|
# https://github.com/kyegomez/BitNet
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
|
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
|
||||||
|
|
||||||
# re-enable logging because zetascale fucking sucks
|
# re-enable logging because zetascale fucking sucks
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
# https://github.com/state-spaces/mamba
|
# https://github.com/state-spaces/mamba
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
|
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
|
||||||
|
|
||||||
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
|
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from transformers import MixtralModel, MixtralConfig
|
from transformers import MixtralModel, MixtralConfig
|
||||||
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
|
||||||
|
|
||||||
# This is required because batch sizes > 1 throws errors
|
# This is required because batch sizes > 1 throws errors
|
||||||
def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
""" """
|
""" """
|
||||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
|
hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
|
||||||
@ -41,5 +42,4 @@ def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> to
|
|||||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||||
return final_hidden_states, router_logits
|
return final_hidden_states, router_logits
|
||||||
|
|
||||||
Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
|
MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
|
||||||
MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
|
|
||||||
@ -207,9 +207,9 @@ class Base(nn.Module):
|
|||||||
return -100
|
return -100
|
||||||
|
|
||||||
def loss_factor(self, k):
|
def loss_factor(self, k):
|
||||||
if self.hyper_config is None:
|
if self.config is None:
|
||||||
return 1.0
|
return 1.0
|
||||||
return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0
|
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -231,8 +231,8 @@ class Base(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.training = training
|
self.training = training
|
||||||
self.hyper_config = config
|
self.config = config
|
||||||
self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
|
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
||||||
|
|
||||||
self.n_text_tokens = n_text_tokens
|
self.n_text_tokens = n_text_tokens
|
||||||
self.n_audio_tokens = n_audio_tokens
|
self.n_audio_tokens = n_audio_tokens
|
||||||
@ -246,7 +246,7 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
# +1 to include the stop token
|
# +1 to include the stop token
|
||||||
n_prom_tokens = n_audio_tokens
|
n_prom_tokens = n_audio_tokens
|
||||||
n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
|
n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
|
||||||
|
|
||||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||||
self.langs_emb = None
|
self.langs_emb = None
|
||||||
@ -263,13 +263,13 @@ class Base(nn.Module):
|
|||||||
self.proms_emb = AudioEmbedding(
|
self.proms_emb = AudioEmbedding(
|
||||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||||
levels=self.n_prom_levels if self.version > 3 else None,
|
levels=self.n_prom_levels if self.version > 3 else None,
|
||||||
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
|
sums=self.config.audio_embedding_sums if self.config is not None else True,
|
||||||
)
|
)
|
||||||
# [1024 + STOP] + [1024] * 8
|
# [1024 + STOP] + [1024] * 8
|
||||||
self.resps_emb = AudioEmbedding(
|
self.resps_emb = AudioEmbedding(
|
||||||
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
||||||
levels=self.n_resp_levels if self.version > 3 else None,
|
levels=self.n_resp_levels if self.version > 3 else None,
|
||||||
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
|
sums=self.config.audio_embedding_sums if self.config is not None else True
|
||||||
)
|
)
|
||||||
|
|
||||||
# useless since I actually removed using these with the input processing overhaul...
|
# useless since I actually removed using these with the input processing overhaul...
|
||||||
@ -290,20 +290,20 @@ class Base(nn.Module):
|
|||||||
self.sep = nn.Parameter(torch.randn(d_model))
|
self.sep = nn.Parameter(torch.randn(d_model))
|
||||||
|
|
||||||
# ick, there has to be a better way
|
# ick, there has to be a better way
|
||||||
hf_attention = self.hyper_config.attention if self.hyper_config is not None else None
|
hf_attention = self.config.attention if self.config is not None else None
|
||||||
|
|
||||||
if self.hyper_config.attention == "auto":
|
if self.config.attention == "auto":
|
||||||
if "flash" in AVAILABLE_ATTENTIONS:
|
if "flash" in AVAILABLE_ATTENTIONS:
|
||||||
self.hyper_config.attention = "flash"
|
self.config.attention = "flash"
|
||||||
elif "xformers" in AVAILABLE_ATTENTIONS:
|
elif "xformers" in AVAILABLE_ATTENTIONS:
|
||||||
self.hyper_config.attention = "xformers"
|
self.config.attention = "xformers"
|
||||||
else:
|
else:
|
||||||
self.hyper_config.attention = "mem_efficient"
|
self.config.attention = "mem_efficient"
|
||||||
|
|
||||||
if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
|
if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
|
||||||
hf_attention = None
|
hf_attention = None
|
||||||
if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
|
if self.config.attention not in AVAILABLE_ATTENTIONS:
|
||||||
raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
|
raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
|
||||||
|
|
||||||
|
|
||||||
if self.arch_type == "transformer":
|
if self.arch_type == "transformer":
|
||||||
@ -326,7 +326,7 @@ class Base(nn.Module):
|
|||||||
num_hidden_layers=n_layers,
|
num_hidden_layers=n_layers,
|
||||||
num_attention_heads=n_heads,
|
num_attention_heads=n_heads,
|
||||||
attention_dropout=p_dropout if training else 0.0,
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
|
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
@ -342,7 +342,7 @@ class Base(nn.Module):
|
|||||||
num_hidden_layers=n_layers,
|
num_hidden_layers=n_layers,
|
||||||
num_attention_heads=n_heads,
|
num_attention_heads=n_heads,
|
||||||
attention_dropout=p_dropout if training else 0.0,
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
|
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
|
||||||
sliding_window=75 * 12, # 12 second context window
|
sliding_window=75 * 12, # 12 second context window
|
||||||
output_router_logits=training,
|
output_router_logits=training,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
@ -492,8 +492,8 @@ class Base(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
|
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
|
||||||
|
|
||||||
if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
|
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
|
||||||
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.hyper_config.attention )
|
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
|
||||||
|
|
||||||
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
||||||
|
|
||||||
@ -691,7 +691,7 @@ class Base(nn.Module):
|
|||||||
quant_levels: Tensor | None = None,
|
quant_levels: Tensor | None = None,
|
||||||
):
|
):
|
||||||
# old, "naive" way, no loss factoring
|
# old, "naive" way, no loss factoring
|
||||||
if not self.hyper_config.loss_factors:
|
if not self.config.loss_factors:
|
||||||
target_list = []
|
target_list = []
|
||||||
for batch_index, batch in enumerate(inputs):
|
for batch_index, batch in enumerate(inputs):
|
||||||
target = []
|
target = []
|
||||||
|
|||||||
@ -54,13 +54,12 @@ def init_tts(restart=False):
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||||
parser.add_argument("--model-ckpt", type=Path, default=None)
|
|
||||||
parser.add_argument("--device", type=str, default="cuda")
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
parser.add_argument("--amp", action="store_true")
|
parser.add_argument("--amp", action="store_true")
|
||||||
parser.add_argument("--dtype", type=str, default="auto")
|
parser.add_argument("--dtype", type=str, default="auto")
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
|
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user