residual_in_fp32=False for mamba arch backends, because keeping the residual in fp32 breaks the classifier (output projection / lm head / what-have-you) under AMP
commit d343bde09b
parent ccb14c06ef
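What the commit message is getting at: with the residual stream kept in float32, the hidden states reaching the classifier come out as fp32 while an AMP / half-precision head expects half inputs. A minimal sketch of that kind of mismatch in plain PyTorch (assumes a CUDA device; the names are illustrative, not the project's):

    import torch

    # stand-in for the classifier / output projection, kept in half precision under AMP
    lm_head = torch.nn.Linear(1024, 256, device="cuda", dtype=torch.float16)

    # stand-in for backbone output when residual_in_fp32=True leaves activations in fp32
    hidden = torch.randn(4, 1024, device="cuda", dtype=torch.float32)

    try:
        lm_head(hidden)  # fp16 weights fed an fp32 input
    except RuntimeError as e:
        print(e)  # e.g. "mat1 and mat2 must have the same dtype"

    # fine once dtypes agree, which residual_in_fp32=False avoids having to do by hand
    lm_head(hidden.to(lm_head.weight.dtype))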
@@ -1024,7 +1024,12 @@ def create_datasets():
 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()
 
-    subtrain_dataset = copy.deepcopy(train_dataset)
+    # it'll cry about trying to pickle a torch._C_generator or something
+    try:
+        subtrain_dataset = copy.deepcopy(train_dataset)
+    except Exception as e:
+        subtrain_dataset = Dataset( training=True )
+
     if subtrain_dataset.sampler_type == "path":
         subtrain_dataset.head_(cfg.evaluation.size)
 
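The try/except above guards copy.deepcopy, which tends to blow up when the dataset holds a torch.Generator (it can't be pickled, hence the comment in the hunk). A self-contained sketch of the same fallback pattern, with ToyDataset as a hypothetical stand-in for the project's Dataset:

    import copy
    import torch

    class ToyDataset:  # hypothetical stand-in for the project's Dataset class
        def __init__(self, training=True):
            self.training = training
            self.generator = torch.Generator()  # the unpicklable member

    train_dataset = ToyDataset(training=True)

    try:
        # typically raises TypeError: cannot pickle 'torch._C.Generator' object
        subtrain_dataset = copy.deepcopy(train_dataset)
    except Exception:
        # fall back to constructing a fresh dataset instead of copying the existing one
        subtrain_dataset = ToyDataset(training=True)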
@@ -587,10 +587,10 @@ class Base(nn.Module):
         d_model=d_model,
         n_layer=n_layers,
         d_intermediate=d_model*4,
-        ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
+        ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
         rms_norm=True,
         fused_add_norm=True,
-        residual_in_fp32=True,
+        residual_in_fp32=False,
         #attn_layer_idx=attn_layer_idx,
         #attn_cfg=attn_cfg,
         #initializer_cfg=initializer_cfg,
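For context, a sketch of how these two settings land in a standalone mamba_ssm backbone (assumes the mamba_ssm package and a CUDA device; the project wires this up differently, so treat everything outside the two changed kwargs as an assumption):

    import torch
    from mamba_ssm.models.config_mamba import MambaConfig
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    d_model, n_layers = 1024, 12
    config = MambaConfig(
        d_model=d_model,
        n_layer=n_layers,
        d_intermediate=d_model * 4,
        vocab_size=256,
        ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False},  # skip the fused mem-efficient path
        rms_norm=True,
        fused_add_norm=True,
        residual_in_fp32=False,  # keep the residual stream in the compute dtype
    )
    model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)

    tokens = torch.randint(0, 256, (1, 32), device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(tokens).logits
    print(logits.dtype)  # the head sees and emits half-precision activations rather than fp32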
@@ -606,6 +606,7 @@ class Base(nn.Module):
         is_encoder_decoder=False,
         is_decoder=True,
         use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
+        residual_in_fp32=False, # breaks for AMP inference
     ))
     if self.gradient_checkpointing and not self.model.gradient_checkpointing:
         self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(