residual_in_fp32=False for mamba arch backends, because keeping residuals in fp32 breaks the classifier (output projection / lm head / what-have-you) under AMP
parent ccb14c06ef
commit d343bde09b
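For context, a minimal sketch of the failure mode the commit message describes, not code from this repo: when the backbone keeps the residual stream in fp32 but the classifier has been cast to half precision for AMP inference, the final projection sees mismatched dtypes and PyTorch raises a RuntimeError. The layer sizes and the standalone lm_head below are made-up placeholders.

import torch
import torch.nn as nn

d_model, vocab_size = 1024, 256                                  # placeholder sizes
lm_head = nn.Linear(d_model, vocab_size, bias=False).bfloat16()  # classifier cast to half precision

hidden_fp32 = torch.randn(1, 4, d_model)  # what residual_in_fp32=True hands to the head

try:
	lm_head(hidden_fp32)                  # dtype mismatch: Float input vs BFloat16 weight
except RuntimeError as e:
	print("AMP-style inference breaks:", e)

# keeping the residual stream in the head's dtype (residual_in_fp32=False) avoids the mismatch
logits = lm_head(hidden_fp32.to(lm_head.weight.dtype))
print(logits.dtype)                       # torch.bfloat16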
@@ -1024,7 +1024,12 @@ def create_datasets():
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
-	subtrain_dataset = copy.deepcopy(train_dataset)
+	# it'll cry about trying to pickle a torch._C_generator or something
+	try:
+		subtrain_dataset = copy.deepcopy(train_dataset)
+	except Exception as e:
+		subtrain_dataset = Dataset( training=True )
 
 	if subtrain_dataset.sampler_type == "path":
 		subtrain_dataset.head_(cfg.evaluation.size)
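A standalone reproduction of the problem the hunk above works around; the ToyDataset class is illustrative, not the repo's Dataset. An object holding a torch.Generator cannot be deep-copied, because the underlying torch._C.Generator is not picklable, so the copy gets wrapped in try/except with a freshly constructed dataset as the fallback.

import copy
import torch

class ToyDataset:                           # stand-in for the repo's Dataset class
	def __init__(self, training=True):
		self.training = training
		self.generator = torch.Generator()  # torch._C.Generator is not picklable

train_dataset = ToyDataset(training=True)

try:
	# typically raises TypeError: cannot pickle 'torch._C.Generator' object
	subtrain_dataset = copy.deepcopy(train_dataset)
except Exception:
	# fall back to building a fresh dataset instead of copying
	subtrain_dataset = ToyDataset(training=True)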
@@ -587,10 +587,10 @@ class Base(nn.Module):
 				d_model=d_model,
 				n_layer=n_layers,
 				d_intermediate=d_model*4,
-				ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
+				ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
-				residual_in_fp32=True,
+				residual_in_fp32=False,
 				#attn_layer_idx=attn_layer_idx,
 				#attn_cfg=attn_cfg,
 				#initializer_cfg=initializer_cfg,
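For reference, a hedged sketch of how a config with these flags is typically consumed by the mamba_ssm package, assuming its MambaConfig / MambaLMHeadModel entry points; the sizes are placeholders and this is not the repo's own wrapper code.

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(
	d_model=1024,                      # placeholder sizes
	n_layer=12,
	d_intermediate=1024 * 4,
	vocab_size=256,
	ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False},
	rms_norm=True,
	fused_add_norm=True,
	residual_in_fp32=False,            # keep residuals in the model / AMP dtype
)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)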
@@ -606,6 +606,7 @@ class Base(nn.Module):
 				is_encoder_decoder=False,
 				is_decoder=True,
 				use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
+				residual_in_fp32=False, # breaks for AMP inference
 			))
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
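Finally, a hedged sketch of the AMP inference path these flags are guarding; `model` and its HF-style `.logits` output are assumptions rather than this repo's API. With residual_in_fp32=False the hidden states that reach the output projection stay in the autocast dtype, so the classifier matmul sees matching dtypes.

import torch

@torch.inference_mode()
def amp_forward(model, input_ids, dtype=torch.float16):
	# fp16 autocast on CUDA; on CPU, bfloat16 is the safer choice
	device_type = input_ids.device.type
	if device_type == "cpu":
		dtype = torch.bfloat16
	# run the whole forward, including the lm head, inside one autocast region
	with torch.autocast(device_type=device_type, dtype=dtype):
		out = model(input_ids)
	# assumes an HF-style output object with a .logits field
	return out.logits[:, -1].argmax(dim=-1)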