mamba updated to fix that pesky NaN error during training

mrq 2024-06-13 12:38:33 -05:00
parent bcf3910a17
commit 26da24fd8d
2 changed files with 2 additions and 2 deletions


@@ -430,7 +430,7 @@ def example_usage():
 	"""
 	model = AR_NAR(**kwargs).to(device)
-	steps = 200
+	steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

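For context: the added line only shortens the smoke-test run in example_usage() when a Mamba arch type is selected. Below is a minimal, self-contained sketch of that pattern; the stand-in model, optimizer, and synthetic data are assumptions for illustration, not this repo's code.

import torch

arch_type = "mamba2"                         # e.g. read from cfg.model.arch_type
steps = 50 if arch_type in ["mamba", "mamba2"] else 200

model = torch.nn.Linear(16, 16)              # stand-in for AR_NAR(**kwargs).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for _ in range(steps):
    x = torch.randn(8, 16)                   # synthetic batch
    loss = torch.nn.functional.mse_loss(model(x), x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()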

@@ -583,7 +583,7 @@ class Base(nn.Module):
 	d_model=d_model,
 	n_layer=n_layers*2,
 	d_intermediate=0,
-	ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {},
+	ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
 	rms_norm=True,
 	fused_add_norm=True,
 	residual_in_fp32=True,
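For context: with the chunk_size override dropped, the ssm_cfg above only selects the Mamba2 mixer and leaves chunking at the library's default. A minimal sketch of how such a config is consumed, assuming mamba_ssm >= 2.x's MixerModel; the dimensions, vocab size, and input below are placeholders, not the values this model actually uses.

import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel

model = MixerModel(
    d_model=256,
    n_layer=12,
    d_intermediate=0,                  # 0 = pure SSM blocks, no interleaved MLP
    vocab_size=1024,
    ssm_cfg={"layer": "Mamba2"},       # pick Mamba2 mixers; chunk_size stays at its default
    rms_norm=True,
    fused_add_norm=True,
    residual_in_fp32=True,
).cuda()

tokens = torch.randint(0, 1024, (1, 64), device="cuda")
hidden = model(tokens)                 # (1, 64, 256) hidden states
print(hidden.shape)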