haha...
commit 1547de5020
parent b7bd885651
@@ -338,7 +338,7 @@ class DeepSpeed:
 	use_compression_training: bool = False
 	compression_bits: int = 8
 	inferencing: bool = False
-	amp: bool = True
+	amp: bool = False
 
 	@cached_property
 	def ds_cfg(self):
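The flipped default above only matters once the flag reaches the DeepSpeed config dict that ds_cfg builds. A minimal sketch of that wiring, assuming the flag maps onto DeepSpeed's "amp" (Apex-style mixed precision) section; the ds_cfg body shown here is hypothetical, not the repo's actual implementation:

from dataclasses import dataclass
from functools import cached_property

@dataclass
class DeepSpeed:
	inferencing: bool = False
	amp: bool = False  # the commit flips this default from True to False

	@cached_property
	def ds_cfg(self):
		# Hypothetical mapping: DeepSpeed's "amp" section takes an "enabled" bool.
		return {
			"amp": {
				"enabled": self.amp,
			},
		}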
@@ -481,6 +481,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -500,6 +501,7 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
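In the MoE branch, num_local_experts and num_experts_per_tok are real MixtralConfig fields: the first sets how many experts each layer holds, the second how many are routed per token, and min(2, n_experts) keeps top-2 routing valid even when only one expert is configured. A standalone sketch with illustrative sizes (the small dimensions are only to keep it cheap to run, not values from the commit):

from transformers import MixtralConfig, MixtralModel

n_experts = 4  # illustrative expert count

config = MixtralConfig(
	hidden_size=256,
	intermediate_size=512,
	num_hidden_layers=2,
	num_attention_heads=4,
	num_key_value_heads=4,
	num_local_experts=n_experts,
	num_experts_per_tok=min(2, n_experts),
)
model = MixtralModel(config)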
@@ -517,6 +519,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -536,7 +539,15 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
+
+			print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
+			if self.activation_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable()
+
+			if training:
+				self.model.training = True
 		elif self.arch_type == "retnet":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
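The added guard leans on the standard transformers checkpointing API: gradient_checkpointing_enable() switches activation checkpointing on after the model is built, and the model reports whether it is active. A self-contained sketch of the same pattern, using the public is_gradient_checkpointing property and illustrative config sizes rather than the repo's own flags:

from transformers import LlamaConfig, LlamaModel

activation_checkpointing = True  # mirrors the commit's flag; illustrative here
model = LlamaModel(LlamaConfig(
	hidden_size=256,
	intermediate_size=512,
	num_hidden_layers=2,
	num_attention_heads=4,
))

# Same guard as the added lines: only enable if it is not already on.
if activation_checkpointing and not model.is_gradient_checkpointing:
	model.gradient_checkpointing_enable()

print("Checkpointing:", activation_checkpointing, model.is_gradient_checkpointing)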