mamba updated to fix that pesky NaN error during training

parent bcf3910a17
commit 26da24fd8d
@@ -430,7 +430,7 @@ def example_usage():
     """

     model = AR_NAR(**kwargs).to(device)
-    steps = 200
+    steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200

     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -583,7 +583,7 @@ class Base(nn.Module):
             d_model=d_model,
             n_layer=n_layers*2,
             d_intermediate=0,
-            ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {},
+            ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
             rms_norm=True,
             fused_add_norm=True,
             residual_in_fp32=True,
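For reference, a minimal sketch (an assumption, not code from this repo) of how these keyword arguments line up with the upstream mamba_ssm package's MambaConfig; the model dimensions and vocab size below are placeholder values:

# Hypothetical sketch: maps the diffed kwargs onto mamba_ssm's config object.
# Values marked "placeholder" are illustrative, not taken from this commit.
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(
    d_model=1024,                  # placeholder
    n_layer=12 * 2,                # placeholder; the diff doubles the layer count (n_layers*2)
    d_intermediate=0,              # 0 disables the MLP sublayer between mixer blocks
    vocab_size=1024,               # placeholder
    ssm_cfg={"layer": "Mamba2"},   # post-fix config: no chunk_size override
    rms_norm=True,
    fused_add_norm=True,
    residual_in_fp32=True,
)
model = MambaLMHeadModel(config)   # the fused Mamba2 kernels need a CUDA device to run

Note that the post-fix ssm_cfg no longer forces "chunk_size": 64 onto the Mamba2 layers, so they fall back to the library's default chunk size.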