From 26da24fd8ded4567fd0330399259c7c41b210dfb Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 13 Jun 2024 12:38:33 -0500
Subject: [PATCH] mamba updated to fix that pesky NaN error during training

---
 vall_e/models/ar_nar.py | 2 +-
 vall_e/models/base.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 23bfd80..ffa1f18 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -430,7 +430,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 200
+	steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 9f381d5..19241e2 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -583,7 +583,7 @@ class Base(nn.Module):
 				d_model=d_model,
 				n_layer=n_layers*2,
 				d_intermediate=0,
-				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {},
+				ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
 				residual_in_fp32=True,
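
Note (not part of the patch): a minimal, hypothetical sketch of how a config like the one patched in base.py is consumed, assuming the upstream mamba_ssm (>= 2.0) MambaConfig / MambaLMHeadModel API rather than the exact mixer class base.py instantiates; after this change, omitting "chunk_size" simply lets Mamba2 fall back to the library's default. The arch_type, d_model, n_layers, and vocab_size values below are illustrative assumptions, not taken from vall_e's config.

# Hypothetical standalone sketch: mirrors the keyword arguments the patched
# base.py passes, but builds mamba_ssm's reference MambaLMHeadModel instead of
# vall_e's own model wiring. Requires a CUDA device; the selective-scan
# kernels are CUDA-only.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

arch_type = "mamba2"                              # assumed stand-in for cfg.model.arch_type
d_model, n_layers, vocab_size = 1024, 12, 1025    # illustrative sizes only

config = MambaConfig(
	d_model=d_model,
	n_layer=n_layers * 2,
	d_intermediate=0,
	vocab_size=vocab_size,
	# Post-patch: no explicit "chunk_size", so Mamba2 uses its default.
	ssm_cfg={"layer": "Mamba2"} if arch_type == "mamba2" else {},
	rms_norm=True,
	fused_add_norm=True,
	residual_in_fp32=True,
)

model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
tokens = torch.randint(0, vocab_size, (1, 128), device="cuda")
logits = model(tokens).logits  # (1, 128, vocab size padded to a multiple of 8)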