mrq 2024-05-09 23:15:52 -05:00
parent b7bd885651
commit 1547de5020
2 changed files with 12 additions and 1 deletion

View File

@@ -338,7 +338,7 @@ class DeepSpeed:
 	use_compression_training: bool = False
 	compression_bits: int = 8
 	inferencing: bool = False
-	amp: bool = True
+	amp: bool = False
 
 	@cached_property
 	def ds_cfg(self):
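
A minimal sketch of how a boolean `amp` field like the one defaulted off above typically feeds DeepSpeed's Apex-based "amp" config section. The stub class, the exact dict layout, and the opt_level value are illustrative assumptions; the repo's real `ds_cfg` property is not shown in this diff.

from dataclasses import dataclass
from functools import cached_property

@dataclass
class DeepSpeedStub:
	amp: bool = False  # mirrors the default flipped above

	@cached_property
	def ds_cfg(self) -> dict:
		# DeepSpeed's "amp" section wraps NVIDIA Apex AMP and is documented as
		# incompatible with ZeRO, which is a common reason to leave it disabled.
		return {"amp": {"enabled": self.amp, "opt_level": "O1"}}

print(DeepSpeedStub().ds_cfg)  # -> {'amp': {'enabled': False, 'opt_level': 'O1'}}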

View File

@@ -481,6 +481,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -500,6 +501,7 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
@@ -517,6 +519,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -536,7 +539,15 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
+
+			print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
+			if self.activation_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable()
+
+			if training:
+				self.model.training = True
 		elif self.arch_type == "retnet":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
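
For context, a standalone sketch of the pattern the second file's hunks introduce: request gradient checkpointing through the HuggingFace config, then fall back to an explicit gradient_checkpointing_enable() call if the flag did not take (passing it via the config is deprecated in recent transformers). The model sizes and the activation_checkpointing/training flags below are placeholders, not the repo's values.

from transformers import LlamaConfig, LlamaModel

activation_checkpointing = True
training = True

model = LlamaModel(LlamaConfig(
	vocab_size=256,        # toy hyperparameters, not the repo's
	hidden_size=256,
	intermediate_size=512,
	num_hidden_layers=2,
	num_attention_heads=4,
	gradient_checkpointing=activation_checkpointing,  # deprecated config route, may only warn
))

# Same guard as the diff: enable checkpointing explicitly if the config route didn't.
print("Checkpointing:", activation_checkpointing, model.gradient_checkpointing)
if activation_checkpointing and not model.gradient_checkpointing:
	model.gradient_checkpointing_enable()

if training:
	model.train()  # the diff sets .training directly; .train() is the standard API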