resolve deprecation warning that doesn't show on my old training rig but does on my new one

mrq 2024-05-09 23:25:44 -05:00
parent 1547de5020
commit 2109712e5b
2 changed files with 11 additions and 6 deletions
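
The warning appears to be transformers' deprecation notice about passing gradient_checkpointing into a model config (slated for removal in v5 in favor of model.gradient_checkpointing_enable()), which would explain why it only shows up on the rig with the newer install. A minimal sketch of the pattern the diff below switches to, assuming transformers' LlamaModel/LlamaConfig and placeholder hyperparameters:

# Sketch only, not from this repo: enable activation checkpointing on the
# constructed model instead of through the deprecated config kwarg.
# Hyperparameter values are placeholders.
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
	vocab_size=1024,
	hidden_size=1024,
	num_attention_heads=16,
	num_hidden_layers=12,
	# gradient_checkpointing=True,  # old style: triggers the deprecation warning
)
model = LlamaModel(config)

activation_checkpointing = True  # stands in for self.activation_checkpointing
if activation_checkpointing and not model.gradient_checkpointing:
	model.gradient_checkpointing_enable()

With the kwarg commented out everywhere, both rigs go through gradient_checkpointing_enable(), so behavior stays the same while the warning goes away.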


@@ -365,7 +365,7 @@ def example_usage():
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 4, # 32
+		'n_layers': 12, # 32
 		'n_experts': 1,
 		'l_padding': 8 if cfg.optimizations.fp8 else 0,
 

@@ -481,7 +481,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -501,8 +501,14 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
+
+			if self.activation_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable()
+
+			if training:
+				self.model.training = True
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
 				self.model = LlamaModel(LlamaConfig(
@@ -519,7 +525,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -539,10 +545,9 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
-			print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
 			if self.activation_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable()
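
In place of the removed debug print, a hedged way to sanity-check the same state, continuing the sketch above and assuming a transformers version that exposes PreTrainedModel.is_gradient_checkpointing:

# Not part of the commit: verify that checkpointing actually took effect,
# covering what the removed print() used to show.
assert model.is_gradient_checkpointing
model.train()  # standard way to put the module tree into training mode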