mamba updated to fix that pesky NaN error during training
parent bcf3910a17
commit 26da24fd8d
@@ -430,7 +430,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 200
+	steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
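For context, a minimal, self-contained sketch (not the repository's actual example_usage) of the selection logic in the hunk above: the smoke-test step count drops to 50 for the mamba/mamba2 arch types, and the optimizer name falls back to "prodigy" when no YAML config is loaded. The SimpleNamespace stand-ins below are illustrative placeholders, not the project's real cfg object.

	# Illustrative stand-in for the project's cfg object (not the real config class).
	from types import SimpleNamespace

	cfg = SimpleNamespace(
		yaml_path=None,
		model=SimpleNamespace(arch_type="mamba2"),
		hyperparameters=SimpleNamespace(optimizer="AdamW", scheduler=""),
	)

	# Same selection logic as the diff above: fewer smoke-test steps for mamba/mamba2,
	# and a "prodigy" optimizer default when no YAML config was supplied.
	steps = 50 if cfg.model.arch_type in ["mamba", "mamba2"] else 200
	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

	print(steps, optimizer, scheduler)  # -> 50 prodigy (empty scheduler string)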
@@ -583,7 +583,7 @@ class Base(nn.Module):
 				d_model=d_model,
 				n_layer=n_layers*2,
 				d_intermediate=0,
-				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {},
+				ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
 				residual_in_fp32=True,
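The second hunk drops the explicit "chunk_size": 64 override from the Mamba2 ssm_cfg, letting the layer fall back to its default chunk size; of the two changes in this commit, this is presumably the one relevant to the NaN error named in the commit message. Below is a minimal sketch of how these kwargs map onto mamba_ssm's MixerModel, assuming illustrative values for d_model, n_layers, and vocab_size; the repository builds its model elsewhere, so the names and sizes here are placeholders, not the project's code.

	import torch
	from mamba_ssm.models.mixer_seq_simple import MixerModel

	# Illustrative sizes only; the real values come from the model config.
	d_model, n_layers, vocab_size = 1024, 12, 1025
	arch_type = "mamba2"

	model = MixerModel(
		d_model=d_model,
		n_layer=n_layers*2,        # the diff doubles the layer count for mamba
		d_intermediate=0,          # no MLP blocks between mixer layers
		vocab_size=vocab_size,
		# post-commit form: select the Mamba2 layer but keep its default chunk size
		ssm_cfg={"layer": "Mamba2"} if arch_type == "mamba2" else {},
		rms_norm=True,
		fused_add_norm=True,
		residual_in_fp32=True,
	).to("cuda")

	tokens = torch.randint(0, vocab_size, (1, 64), device="cuda")
	hidden = model(tokens)         # (1, 64, d_model) hidden states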