haha...

commit 1547de5020
parent b7bd885651
@@ -338,7 +338,7 @@ class DeepSpeed:
 	use_compression_training: bool = False
 	compression_bits: int = 8
 	inferencing: bool = False
-	amp: bool = True
+	amp: bool = False
 
 	@cached_property
 	def ds_cfg(self):
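The only change in the DeepSpeed hunk is the default of the `amp` flag flipping from True to False, making automatic mixed precision opt-in. The `ds_cfg` property that consumes this flag is not shown in the diff, so the following is only a minimal, hypothetical sketch of how such a flag could be surfaced in the generated DeepSpeed config dict; `DeepSpeedSettings` and `build_ds_config` are illustrative names, not the repo's own.

# Hypothetical sketch only: the diff does not show ds_cfg, so this maps the
# dataclass flag onto DeepSpeed's "amp" config section in the obvious way.
from dataclasses import dataclass

@dataclass
class DeepSpeedSettings:
	use_compression_training: bool = False
	compression_bits: int = 8
	inferencing: bool = False
	amp: bool = False  # new default from this commit: AMP is opt-in

def build_ds_config(cfg: DeepSpeedSettings) -> dict:
	# DeepSpeed reads its (Apex-style) AMP settings from the "amp" section
	# of the config dict it is initialized with.
	return {
		"amp": {"enabled": cfg.amp},
	}

print(build_ds_config(DeepSpeedSettings()))  # {'amp': {'enabled': False}}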
@@ -481,6 +481,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -500,6 +501,7 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
@@ -517,6 +519,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -536,7 +539,15 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
+
+			print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
+			if self.activation_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable()
+
+			if training:
+				self.model.training = True
 		elif self.arch_type == "retnet":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
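The remaining hunks thread `gradient_checkpointing=self.activation_checkpointing` into each Hugging Face config and, after the model is built, call `gradient_checkpointing_enable()` when checkpointing is requested but not yet active. Below is a minimal sketch of that enable-if-needed pattern, assuming a recent `transformers` release where `LlamaModel`/`LlamaConfig` and the `is_gradient_checkpointing` property are available; the tiny model dimensions are only to keep the example cheap and are not taken from the repo.

# Minimal sketch of the enable-if-needed pattern from this commit (assumed
# transformers API; not the repo's Base class itself).
from transformers import LlamaConfig, LlamaModel

activation_checkpointing = True  # stand-in for self.activation_checkpointing

model = LlamaModel(LlamaConfig(
	vocab_size=256,        # tiny, illustrative sizes
	hidden_size=64,
	intermediate_size=128,
	num_hidden_layers=2,
	num_attention_heads=4,
))

print("Checkpointing:", activation_checkpointing, model.is_gradient_checkpointing)
if activation_checkpointing and not model.is_gradient_checkpointing:
	# Turns on torch.utils.checkpoint for the decoder layers, trading extra
	# compute in the backward pass for a smaller activation-memory footprint.
	model.gradient_checkpointing_enable()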