diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 5ccb1cc..e54954d 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -13,7 +13,6 @@ def main():
 	parser.add_argument("--out-path", type=Path, default=None)
 
 	parser.add_argument("--yaml", type=Path, default=None)
-	parser.add_argument("--model-ckpt", type=Path, default=None)
 
 	parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
 	parser.add_argument("--max-nar-levels", type=int, default=7)
@@ -40,7 +39,7 @@ def main():
 	parser.add_argument("--dtype", type=str, default=None)
 	args = parser.parse_args()
 
-	tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
+	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
 	tts.inference(
 		text=args.text,
 		references=args.references,
diff --git a/vall_e/config.py b/vall_e/config.py
index b2bf4a6..ba8889d 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -256,7 +256,7 @@ class Model:
 		if self.interleave:
 			name.append("interleaved")
 		else:
-			name.append(f'{cfg.model.prom_levels}')
+			name.append(f'{self.prom_levels}')
 
 		return "-".join(name)
 
@@ -627,8 +627,7 @@ class Config(_Config):
 	experimental: bool = False # So I can stop commenting out things when committing
 
 	dataset: Dataset = field(default_factory=lambda: Dataset)
-	model: Model = field(default_factory=lambda: Model)
-	models: dict | list | None = None # deprecated
+	models: dict | list | None = field(default_factory=lambda: [Model])
 	hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
@@ -643,6 +642,14 @@ class Config(_Config):
 
 	audio_backend: str = "vocos"
 
+	@property
+	def model(self):
+		for i, model in enumerate(self.models):
+			if model.training:
+				return model
+
+		return self.models[0]
+
 	@property
 	def distributed(self):
 		return world_size() > 1
@@ -681,8 +688,8 @@ class Config(_Config):
 		if isinstance(self.dataset, type):
 			self.dataset = dict()
 
-		if isinstance(self.model, type):
-			self.model = dict()
+		if isinstance(self.models, type):
+			self.models = dict()
 
 		if isinstance(self.hyperparameters, type):
 			self.hyperparameters = dict()
@@ -704,10 +711,14 @@ class Config(_Config):
 			self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
 			self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
 
+		"""
 		if self.models is not None:
 			self.model = Model(**next(iter(self.models)))
 		else:
 			self.model = Model(**self.model)
+		"""
+
+		self.models = [ Model(**model) for model in self.models ]
 
 		self.hyperparameters = Hyperparameters(**self.hyperparameters)
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 4fb7f45..9f3e8cb 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -26,14 +26,14 @@ from functools import cache
 
 @cache
 def load_engines(training=True):
-	models = get_models(cfg.model.get(), training=training)
+	models = get_models(cfg.models, training=training)
 	engines = dict()
 
 	for name, model in models.items():
 		optimizer = None
 		lr_scheduler = None
 
-		inferencing = cfg.mode == "inferencing" or not model.hyper_config.training
+		inferencing = cfg.mode == "inferencing" or not model.config.training
 		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 		amp = cfg.inference.amp if inferencing else cfg.trainer.amp
@@ -43,7 +43,7 @@ def load_engines(training=True):
 		engine_class = _Engine if backend == "local" or inferencing else Engine
 
 		if inferencing:
-			model.hyper_config.training = False
+			model.config.training = False
 
 		if cfg.optimizations.replace and cfg.optimizations.linear:
 			model.model = ml.replace_linear( model.model )
@@ -83,7 +83,7 @@ def load_engines(training=True):
 			params.update(cfg.hyperparameters.optimizer_params)
 
 			optimizer = optimizer_class(
-				[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
+				[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
 				**params,
 			)
 
@@ -96,7 +96,7 @@ def load_engines(training=True):
 				raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
 
 			optimizer = scheduler_class(
-				[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
+				[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
 				lr = params['lr'],
 				warmup_steps = cfg.hyperparameters.warmup_steps
 			)
@@ -143,12 +143,16 @@ def load_engines(training=True):
 					del state[k]
 
 			# resize text embedding
-			if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
-				state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
+			if model.config.text_tokens != state["text_emb.weight"].shape[0]:
+				state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
+
+			# resize rvq level embedding
+			if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
+				state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
-		hyper_config = model.hyper_config
+		hyper_config = model.config
 
 		# wrap if DDP is requested
 		if ddp:
diff --git a/vall_e/inference.py b/vall_e/inference.py
index cd3cde0..07d93d9 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -19,7 +19,7 @@ if deepspeed_available:
 	import deepspeed
 
 class TTS():
-	def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ):
+	def __init__( self, config=None, device=None, amp=None, dtype=None ):
 		self.loading = True
 
 		self.input_sample_rate = 24000
@@ -53,32 +53,12 @@ class TTS():
 
 		self.symmap = None
 
-		if model_ckpt:
-			state = torch.load(model_ckpt)
-			self.model = get_models(cfg.model.get(), training=False)[0]
-
-			if "userdata" in state and 'symmap' in state['userdata']:
-				self.symmap = state['userdata']['symmap']
-			elif "symmap" in state:
-				self.symmap = state['symmap']
+		self.engines = load_engines(training=False)
+		for name, engine in self.engines.items():
+			if self.dtype != torch.int8:
+				engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 
-			if "module" in state:
-				state = state['module']
-
-			self.model.load_state_dict(state)
-
-			if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
-				self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
-		else:
-			engines = load_engines(training=False)
-			for name, engine in engines.items():
-				self.model = engine.module
-				break
-
-		if self.dtype != torch.int8:
-			self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-
-		self.model.eval()
+		self.engines.eval()
 
 		if self.symmap is None:
 			self.symmap = get_phone_symmap()
@@ -159,6 +139,15 @@ class TTS():
 		wavs = []
 		sr = None
 
+		model_ar = None
+		model_nar = None
+
+		for name, engine in self.engines.items():
+			if "ar" in engine.hyper_config.capabilities:
+				model_ar = engine.module
+			if "nar" in engine.hyper_config.capabilities:
+				model_nar = engine.module
+
 		for line in lines:
 			if out_path is None:
 				out_path = f"./data/{cfg.start_time}.wav"
@@ -172,7 +161,7 @@ class TTS():
 			lang = to_device(lang, self.device).to(torch.uint8)
 
 			with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
-				resps_list = self.model(
+				resps_list = model_ar(
 					text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
 					sampling_temperature=ar_temp,
 					sampling_min_temperature=min_ar_temp,
@@ -183,8 +172,7 @@ class TTS():
 					sampling_mirostat_tau=mirostat_tau,
 					sampling_mirostat_eta=mirostat_eta,
 				)
-				resps_list = [r.unsqueeze(-1) for r in resps_list]
-				resps_list = self.model(
+				resps_list = model_nar(
 					text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
 					max_levels=max_nar_levels,
 					sampling_temperature=nar_temp,
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 66d0564..d2f1bc6 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -1,37 +1,36 @@
-def get_model(cfg, training=True):
-	name = cfg.name
+def get_model(config, training=True):
+	name = config.name
 
-	if not cfg.experimental:
+	if not config.experimental:
 		from .ar_nar import AR_NAR
 		model = AR_NAR(
-			n_text_tokens=cfg.text_tokens,
-			n_audio_tokens=cfg.audio_tokens,
-			d_model=cfg.dim,
-			n_heads=cfg.heads,
-			n_layers=cfg.layers,
-			n_experts=cfg.experts,
+			n_text_tokens=config.text_tokens,
+			n_audio_tokens=config.audio_tokens,
+			d_model=config.dim,
+			n_heads=config.heads,
+			n_layers=config.layers,
+			n_experts=config.experts,
 
-			p_dropout=cfg.dropout,
+			p_dropout=config.dropout,
 
-			l_padding = cfg.input_alignment,
+			l_padding = config.input_alignment,
 
 			training = training,
-			config = cfg,
+			config = config,
 		)
-		model._cfg = cfg
 	else:
 		from .experimental import Model as Experimental
 		model = Experimental(
-			n_text_tokens=cfg.text_tokens,
-			n_audio_tokens=cfg.audio_tokens,
-
-			d_model=cfg.dim,
-			n_layers=cfg.layers,
-			n_heads=cfg.heads,
-			p_dropout=cfg.dropout,
+			n_text_tokens=config.text_tokens,
+			n_audio_tokens=config.audio_tokens,
 
-			config = cfg,
+			d_model=config.dim,
+			n_layers=config.layers,
+			n_heads=config.heads,
+			p_dropout=config.dropout,
+
+			config = config,
 		)
 
 	print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ecbcb3c..f3b99fc 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -16,7 +16,7 @@ class AR_NAR(Base):
 	@property
 	def causal(self):
 		if hasattr(self, "config") and self.config:
-			return "ar" in self.capabilities
+			return "ar" in self.config.capabilities
 		return True
 
 	@property
@@ -31,6 +31,8 @@ class AR_NAR(Base):
 
 	@property
 	def n_prom_levels(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.prom_levels
 		return cfg.model.prom_levels
 
 	@property
@@ -41,18 +43,26 @@ class AR_NAR(Base):
 
 	@property
 	def n_max_levels(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.max_levels
 		return cfg.model.max_levels
 
 	@property
 	def n_tasks(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.tasks
 		return cfg.model.tasks
 
 	@property
 	def n_langs(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.langs
 		return cfg.model.langs
 
 	@property
 	def n_tones(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.tones
 		return cfg.model.tones
 
 	@property
diff --git a/vall_e/models/arch/bitnet.py b/vall_e/models/arch/bitnet.py
index f93021b..93d18f3 100644
--- a/vall_e/models/arch/bitnet.py
+++ b/vall_e/models/arch/bitnet.py
@@ -1,5 +1,7 @@
 # https://github.com/kyegomez/BitNet
 from torch import Tensor, nn
+from torch.utils.checkpoint import checkpoint
+
 from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
 
 # re-enable logging because zetascale fucking sucks
diff --git a/vall_e/models/arch/mamba.py b/vall_e/models/arch/mamba.py
index 078011d..4389b13 100644
--- a/vall_e/models/arch/mamba.py
+++ b/vall_e/models/arch/mamba.py
@@ -1,4 +1,6 @@
 # https://github.com/state-spaces/mamba
+from torch.utils.checkpoint import checkpoint
+
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
 
 def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
diff --git a/vall_e/models/arch/mixtral.py b/vall_e/models/arch/mixtral.py
index 568cd7c..e02dc13 100644
--- a/vall_e/models/arch/mixtral.py
+++ b/vall_e/models/arch/mixtral.py
@@ -1,12 +1,13 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
 import torch
+import torch.nn.functional as F
 
 from transformers import MixtralModel, MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
 
 # This is required because batch sizes > 1 throws errors
-def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 	""" """
 	batch_size, sequence_length, hidden_dim = hidden_states.shape
 	hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
@@ -41,5 +42,4 @@ def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> to
 	final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 	return final_hidden_states, router_logits
 
-Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
-MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
\ No newline at end of file
+MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
\ No newline at end of file
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0bc1445..2511735 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -207,9 +207,9 @@ class Base(nn.Module):
 		return -100
 
 	def loss_factor(self, k):
-		if self.hyper_config is None:
+		if self.config is None:
 			return 1.0
-		return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0
+		return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
 
 	def __init__(
 		self,
@@ -231,8 +231,8 @@
 	):
 		super().__init__()
 		self.training = training
-		self.hyper_config = config
-		self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
+		self.config = config
+		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
 
 		self.n_text_tokens = n_text_tokens
 		self.n_audio_tokens = n_audio_tokens
@@ -246,7 +246,7 @@ class Base(nn.Module):
 
 		# +1 to include the stop token
 		n_prom_tokens = n_audio_tokens
-		n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
+		n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -263,13 +263,13 @@ class Base(nn.Module):
 		self.proms_emb = AudioEmbedding(
 			[n_prom_tokens] * self.n_prom_levels, d_model,
 			levels=self.n_prom_levels if self.version > 3 else None,
-			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
+			sums=self.config.audio_embedding_sums if self.config is not None else True,
 		)
 		# [1024 + STOP] + [1024] * 8
 		self.resps_emb = AudioEmbedding(
 			[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
 			levels=self.n_resp_levels if self.version > 3 else None,
-			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
+			sums=self.config.audio_embedding_sums if self.config is not None else True
 		)
 
 		# useless since I actually removed using these with the input processing overhaul...
@@ -290,20 +290,20 @@ class Base(nn.Module):
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
 		# ick, there has to be a better way
-		hf_attention = self.hyper_config.attention if self.hyper_config is not None else None
+		hf_attention = self.config.attention if self.config is not None else None
 
-		if self.hyper_config.attention == "auto":
+		if self.config.attention == "auto":
 			if "flash" in AVAILABLE_ATTENTIONS:
-				self.hyper_config.attention = "flash"
+				self.config.attention = "flash"
 			elif "xformers" in AVAILABLE_ATTENTIONS:
-				self.hyper_config.attention = "xformers"
+				self.config.attention = "xformers"
 			else:
-				self.hyper_config.attention = "mem_efficient"
+				self.config.attention = "mem_efficient"
 
-		if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
+		if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
 			hf_attention = None
-			if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
-				raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
+			if self.config.attention not in AVAILABLE_ATTENTIONS:
+				raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
 
 		if self.arch_type == "transformer":
@@ -326,7 +326,7 @@ class Base(nn.Module):
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout if training else 0.0,
-				num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
+				num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
@@ -342,7 +342,7 @@ class Base(nn.Module):
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout if training else 0.0,
-				num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
+				num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
 				sliding_window=75 * 12, # 12 second context window
 				output_router_logits=training,
 				hidden_act="gelu",
@@ -492,8 +492,8 @@ class Base(nn.Module):
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
-		if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
-			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.hyper_config.attention )
+		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
+			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
 
@@ -691,7 +691,7 @@ class Base(nn.Module):
 		quant_levels: Tensor | None = None,
 	):
 		# old, "naive" way, no loss factoring
-		if not self.hyper_config.loss_factors:
+		if not self.config.loss_factors:
 			target_list = []
 			for batch_index, batch in enumerate(inputs):
 				target = []
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 0e1a7c2..9a9071a 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -54,13 +54,12 @@ def init_tts(restart=False):
 
 	parser = argparse.ArgumentParser(allow_abbrev=False)
 	parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
-	parser.add_argument("--model-ckpt", type=Path, default=None)
 	parser.add_argument("--device", type=str, default="cuda")
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default="auto")
 	args, unknown = parser.parse_known_args()
 
-	tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
+	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
 	return tts
 
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())