diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 23bfd80..ffa1f18 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -430,7 +430,7 @@ def example_usage(): """ model = AR_NAR(**kwargs).to(device) - steps = 200 + steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200 optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 9f381d5..19241e2 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -583,7 +583,7 @@ class Base(nn.Module): d_model=d_model, n_layer=n_layers*2, d_intermediate=0, - ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {}, + ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {}, rms_norm=True, fused_add_norm=True, residual_in_fp32=True,