From d343bde09b6874205358a26cf0e88d5cc0af44e5 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 15 Jun 2024 12:08:03 -0500
Subject: [PATCH] residual_in_fp32=False for mamba arch backends because it
 breaks the classifier (output projection / lm head / what-have-you) under
 AMP

---
 vall_e/data.py        | 7 ++++++-
 vall_e/models/base.py | 5 +++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 3fd0252..9f414df 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1024,7 +1024,12 @@ def create_datasets():
 
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
-	subtrain_dataset = copy.deepcopy(train_dataset)
+	# it'll cry about trying to pickle a torch._C_generator or something
+	try:
+		subtrain_dataset = copy.deepcopy(train_dataset)
+	except Exception as e:
+		subtrain_dataset = Dataset( training=True )
+
 	if subtrain_dataset.sampler_type == "path":
 		subtrain_dataset.head_(cfg.evaluation.size)
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e3d85a8..b8671a0 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -587,10 +587,10 @@ class Base(nn.Module):
 				d_model=d_model,
 				n_layer=n_layers,
 				d_intermediate=d_model*4,
-				ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
+				ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
-				residual_in_fp32=True,
+				residual_in_fp32=False,
 				#attn_layer_idx=attn_layer_idx,
 				#attn_cfg=attn_cfg,
 				#initializer_cfg=initializer_cfg,
@@ -606,6 +606,7 @@ class Base(nn.Module):
 				is_encoder_decoder=False,
 				is_decoder=True,
 				use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
+				residual_in_fp32=False, # breaks for AMP inference
 			))
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
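
Not part of the patch: a minimal sketch of the failure the subject line
describes, under the assumption that it is the usual dtype clash between an
fp32 residual stream and a half-precision classifier head during AMP
inference. The layer name, shapes, and the use of bfloat16 (chosen so the
sketch also runs on CPU) are made up for illustration.

    import torch
    import torch.nn as nn

    # stand-in for the lm head / output projection after the model has been
    # cast down to a low-precision dtype for AMP inference
    classifier = nn.Linear(1024, 256).to(torch.bfloat16)

    # with residual_in_fp32=True, the backbone can hand the head fp32 hidden states
    hidden = torch.randn(1, 32, 1024)

    try:
        classifier(hidden)             # fp32 input x bf16 weight
    except RuntimeError as err:
        print(err)                     # dtype mismatch between input and weight

    # residual_in_fp32=False keeps the residual stream in the compute dtype,
    # which is effectively what casting the input back down does here:
    logits = classifier(hidden.to(torch.bfloat16))
    print(logits.dtype)                # torch.bfloat16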