residual_in_fp32=False for mamba arch backends, since residual_in_fp32=True breaks the classifier (output projection / lm head / what-have-you) under AMP

mrq 2024-06-15 12:08:03 -05:00
parent ccb14c06ef
commit d343bde09b
2 changed files with 9 additions and 3 deletions


@@ -1024,7 +1024,12 @@ def create_datasets():
 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()
-    subtrain_dataset = copy.deepcopy(train_dataset)
+    # it'll cry about trying to pickle a torch._C_generator or something
+    try:
+        subtrain_dataset = copy.deepcopy(train_dataset)
+    except Exception as e:
+        subtrain_dataset = Dataset( training=True )
     if subtrain_dataset.sampler_type == "path":
         subtrain_dataset.head_(cfg.evaluation.size)
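
As a minimal, self-contained sketch of the fallback above (ToyDataset is a hypothetical stand-in for the project's Dataset class): copy.deepcopy may raise when the object holds a torch.Generator, the torch._C.Generator the comment alludes to, so the except branch simply builds a fresh instance.

import copy
import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Hypothetical stand-in for the project's Dataset; not the real class."""
    def __init__(self, training=True):
        self.training = training
        # a torch.Generator (torch._C.Generator) is what deepcopy/pickle may choke on
        self.rng = torch.Generator().manual_seed(0)
        self.items = list(range(8))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

train_dataset = ToyDataset(training=True)
try:
    subtrain_dataset = copy.deepcopy(train_dataset)
except Exception:
    # cheaper to rebuild from scratch than to teach deepcopy about the generator state
    subtrain_dataset = ToyDataset(training=True)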


@@ -587,10 +587,10 @@ class Base(nn.Module):
     d_model=d_model,
     n_layer=n_layers,
     d_intermediate=d_model*4,
-    ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
+    ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
     rms_norm=True,
     fused_add_norm=True,
-    residual_in_fp32=True,
+    residual_in_fp32=False,
     #attn_layer_idx=attn_layer_idx,
     #attn_cfg=attn_cfg,
     #initializer_cfg=initializer_cfg,
@@ -606,6 +606,7 @@ class Base(nn.Module):
     is_encoder_decoder=False,
     is_decoder=True,
     use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
+    residual_in_fp32=False, # breaks for AMP inference
 ))
 if self.gradient_checkpointing and not self.model.gradient_checkpointing:
     self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
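
To illustrate why a residual stream kept in fp32 can clash with a reduced-precision classifier head, here is a toy sketch, not the project's code: bfloat16 is used only so it runs on CPU, the same mismatch occurs with fp16, and per the commit message the actual breakage happens somewhere inside the mamba backends under AMP.

import torch
import torch.nn as nn

torch.manual_seed(0)

hidden = torch.randn(1, 4, 16)                  # pretend backbone output / residual stream
lm_head = nn.Linear(16, 32).to(torch.bfloat16)  # classifier cast down for inference

for residual_in_fp32 in (True, False):
    x = hidden.float() if residual_in_fp32 else hidden.to(torch.bfloat16)
    try:
        logits = lm_head(x)
        print(f"residual_in_fp32={residual_in_fp32}: ok, logits are {logits.dtype}")
    except RuntimeError as err:
        # fp32 activations against reduced-precision weights raise a dtype mismatch
        print(f"residual_in_fp32={residual_in_fp32}: failed ({err})")

Keeping the residual in the reduced dtype (residual_in_fp32=False) sidesteps the mismatch at the cost of a slightly less precise residual accumulation.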