mrq 2024-05-09 23:15:52 -05:00
parent b7bd885651
commit 1547de5020
2 changed files with 12 additions and 1 deletion

View File

@@ -338,7 +338,7 @@ class DeepSpeed:
     use_compression_training: bool = False
     compression_bits: int = 8
     inferencing: bool = False
-    amp: bool = True
+    amp: bool = False
     @cached_property
     def ds_cfg(self):
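For context on the `amp` default flip: DeepSpeed's Apex-based "amp" section of the JSON config is the kind of switch a flag like this usually ends up driving when `ds_cfg` is assembled. That mapping is not visible in this hunk, so the snippet below is only a sketch under that assumption; the section names are real DeepSpeed config keys, the surrounding values are placeholders.

    # Sketch only: assumes the dataclass's `amp` flag controls DeepSpeed's Apex-based
    # "amp" block when ds_cfg is generated; the actual mapping is not part of this diff.
    ds_cfg = {
        "train_micro_batch_size_per_gpu": 1,  # illustrative placeholder
        "amp": {
            "enabled": False,   # mirrors the new default: amp = False
            "opt_level": "O1",  # Apex optimization level; ignored while disabled
        },
        # DeepSpeed's native mixed precision lives in separate sections, e.g.
        # "fp16": {"enabled": True} or "bf16": {"enabled": True}, and is not
        # meant to be combined with the Apex "amp" block.
    }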

View File

@@ -481,6 +481,7 @@ class Base(nn.Module):
                     is_encoder_decoder=False,
                     is_decoder=True,
                     attn_implementation=attention,
+                    gradient_checkpointing=self.activation_checkpointing,
                 ))
             else:
                 self.model = MixtralModel(MixtralConfig(
@@ -500,6 +501,7 @@ class Base(nn.Module):
                     num_local_experts=n_experts,
                     num_experts_per_tok=min(2, n_experts),
                     attn_implementation=attention,
+                    gradient_checkpointing=self.activation_checkpointing,
                 ))
         elif self.arch_type == "llama":
             if n_experts <= 1:
@@ -517,6 +519,7 @@ class Base(nn.Module):
                     is_encoder_decoder=False,
                     is_decoder=True,
                     attn_implementation=attention,
+                    gradient_checkpointing=self.activation_checkpointing,
                 ))
             else:
                 self.model = MixtralModel(MixtralConfig(
@@ -536,7 +539,15 @@ class Base(nn.Module):
                     num_local_experts=n_experts,
                     num_experts_per_tok=min(2, n_experts),
                     attn_implementation=attention,
+                    gradient_checkpointing=self.activation_checkpointing,
                 ))
+
+            print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
+            if self.activation_checkpointing and not self.model.gradient_checkpointing:
+                self.model.gradient_checkpointing_enable()
+
+            if training:
+                self.model.training = True
         elif self.arch_type == "retnet":
             kwargs = dict(
                 vocab_size=n_resp_tokens,
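The runtime fallback here calls `gradient_checkpointing_enable()` when the config kwarg alone did not flip the module flag; recent transformers releases steer users toward that method rather than passing `gradient_checkpointing` through the config. A minimal, self-contained sketch of that API (hyperparameters are placeholders, not this model's real ones):

    # Sketch of the transformers checkpointing API the added lines rely on.
    # Config values are illustrative placeholders, not this model's hyperparameters.
    import torch
    from transformers import LlamaConfig, LlamaModel

    config = LlamaConfig(
        vocab_size=256,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    model = LlamaModel(config)

    # gradient_checkpointing_enable() is the supported toggle; it sets the module-level
    # `gradient_checkpointing` flag that the runtime check above inspects.
    model.gradient_checkpointing_enable()
    print(model.gradient_checkpointing)  # -> True

    # Activations are only recomputed while gradients are tracked, so the model also
    # needs to be in training mode; model.train() flips `training` on all submodules.
    model.train()
    out = model(input_ids=torch.randint(0, 256, (1, 8)))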