diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index f98bd4f..3f7c139 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -317,7 +317,7 @@ class AR_NAR(Base): def example_usage(): - #cfg.trainer.backend = "local" + cfg.trainer.backend = "local" cfg.hyperparameters.gradient_accumulation_steps = 1 if cfg.audio_backend == "dac": cfg.sample_rate = 44_000 @@ -334,7 +334,11 @@ def example_usage(): import re device = "cuda" - x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels) + + # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it) + if "mamba" in cfg.model.arch_type: + cfg.model.prom_levels = 1 + cfg.model.resp_levels = 1 def tokenize(content): return torch.tensor( cfg.tokenizer.encode(content) ) @@ -368,7 +372,7 @@ def example_usage(): 'n_tokens': 1024, 'd_model': 1024, # 256, # 1024, # 1536 'n_heads': 16, # 4, # 16, # 24 - 'n_layers': 8, # 32 + 'n_layers': 12, # 32 'n_experts': 1, 'p_dropout': 0.1, @@ -386,7 +390,7 @@ def example_usage(): """ model = AR_NAR(**kwargs).to(device) - steps = 50 + steps = 250 optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else "" @@ -459,11 +463,12 @@ def example_usage(): engine.eval() resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 ) - resps_list = [r.unsqueeze(-1) for r in resps_list] - resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 ) + if cfg.model.max_levels > 1: + resps_list = [r.unsqueeze(-1) for r in resps_list] + resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 ) for i, o in enumerate(resps_list): - _ = decode_to_file(o, f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device) + _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device) unload_model() diff --git a/vall_e/models/base.py b/vall_e/models/base.py index f1dd72a..8e59982 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -155,6 +155,40 @@ except Exception as e: print("Error importing `mixtral` arch:", e) +try: + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm + + def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs): + if hidden_states is None: + hidden_states = self.embedding(input_ids) + residual = None + for layer in self.layers: + if self.gradient_checkpointing and hidden_states.requires_grad: + hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False ) + else: + hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + hidden_states = MambaLayerNormFn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm_f, MambaRMSNorm) + ) + return hidden_states + + MambaMixelModel.forward = MambaMixelModel_forward +except Exception as e: + print("Error importing `mixtral` arch:", e) + + AVAILABLE_ATTENTIONS = ["mem_efficient", "math"] try: @@ -686,6 +720,21 @@ class Base(nn.Module): ff_mult=4, gradient_checkpointing=self.gradient_checkpointing, ) + elif self.arch_type in ["mamba","mamba2"]: + self.model = MambaMixelModel( + vocab_size=n_resp_tokens, + d_model=d_model, + n_layer=n_layers*2, + d_intermediate=0, + ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {}, + rms_norm=True, + fused_add_norm=True, + residual_in_fp32=True, + #attn_layer_idx=attn_layer_idx, + #attn_cfg=attn_cfg, + #initializer_cfg=initializer_cfg, + ) + self.model.gradient_checkpointing = self.gradient_checkpointing else: raise RuntimeError(f'Unknown arch specified: {self.arch_type}') @@ -804,7 +853,8 @@ class Base(nn.Module): x = out.last_hidden_state if state is not None: state = out.past_key_values - + elif self.arch_type in ["mamba","mamba2"]: + x = self.model( hidden_states=x ) elif self.arch_type == "bitnet": x = self.model(x) diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index 9f8e37f..052c3d9 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -158,6 +158,8 @@ class Model(LlmArchClass): d_model=d_model, n_layer=n_layers*2, ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {}, + fused_add_norm=True, + residual_in_fp32=True, )) self.backbone.gradient_checkpointing = gradient_checkpointing