diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ffa1f18..e2ae82c 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -430,7 +430,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200
+	steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py
index ee6dc41..c0843f4 100755
--- a/vall_e/models/arch/__init__.py
+++ b/vall_e/models/arch/__init__.py
@@ -49,7 +49,7 @@ except Exception as e:
 	print("Error importing `mixtral` arch:", e)
 
 try:
-	from .mamba import MambaMixelModel, MambaLMHeadModel
+	from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
 	AVAILABLE_ARCHES.append("mamba")
 	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 19241e2..b8e5b92 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -581,8 +581,8 @@ class Base(nn.Module):
 			self.model = MambaMixelModel(
 				vocab_size=n_resp_tokens,
 				d_model=d_model,
-				n_layer=n_layers*2,
-				d_intermediate=0,
+				n_layer=n_layers,
+				d_intermediate=d_model*4,
 				ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
@@ -1092,7 +1092,8 @@ class Base(nn.Module):
 		logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
 		# perform repetition penalizing
-		logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+		if "len" not in self.capabilities:
+			logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
 
 		# argmax instead
 		if temperature <= 0.0:
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index 0e8e68a..6cb5aba 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -106,8 +106,10 @@ class Model(LlmArchClass):
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
-				n_layer=n_layers*2,
-				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if cfg.model.arch_type == "mamba2" else {},
+				n_layer=n_layers,
+				d_intermediate=d_model*4,
+				ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
+				rms_norm=True,
 				fused_add_norm=True,
 				residual_in_fp32=True,
 			))