Resolve a deprecation warning that doesn't show up on my old training rig but does on my new one
This commit is contained in:
parent
1547de5020
commit
2109712e5b
|
@ -365,7 +365,7 @@ def example_usage():
|
|||
'n_tokens': 1024,
|
||||
'd_model': 1024, # 256, # 1024, # 1536
|
||||
'n_heads': 16, # 4, # 16, # 24
|
||||
'n_layers': 4, # 32
|
||||
'n_layers': 12, # 32
|
||||
'n_experts': 1,
|
||||
|
||||
'l_padding': 8 if cfg.optimizations.fp8 else 0,
|
||||
|
|
|
@ -481,7 +481,7 @@ class Base(nn.Module):
|
|||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
attn_implementation=attention,
|
||||
gradient_checkpointing=self.activation_checkpointing,
|
||||
#gradient_checkpointing=self.activation_checkpointing,
|
||||
))
|
||||
else:
|
||||
self.model = MixtralModel(MixtralConfig(
|
||||
|
@ -501,8 +501,14 @@ class Base(nn.Module):
|
|||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=min(2, n_experts),
|
||||
attn_implementation=attention,
|
||||
gradient_checkpointing=self.activation_checkpointing,
|
||||
#gradient_checkpointing=self.activation_checkpointing,
|
||||
))
|
||||
|
||||
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
if training:
|
||||
self.model.training = True
|
||||
elif self.arch_type == "llama":
|
||||
if n_experts <= 1:
|
||||
self.model = LlamaModel(LlamaConfig(
|
||||
|
@ -519,7 +525,7 @@ class Base(nn.Module):
|
|||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
attn_implementation=attention,
|
||||
gradient_checkpointing=self.activation_checkpointing,
|
||||
#gradient_checkpointing=self.activation_checkpointing,
|
||||
))
|
||||
else:
|
||||
self.model = MixtralModel(MixtralConfig(
|
||||
|
@ -539,10 +545,9 @@ class Base(nn.Module):
|
|||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=min(2, n_experts),
|
||||
attn_implementation=attention,
|
||||
gradient_checkpointing=self.activation_checkpointing,
|
||||
#gradient_checkpointing=self.activation_checkpointing,
|
||||
))
|
||||
|
||||
print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
|
||||
if self.activation_checkpointing and not self.model.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user