diff --git a/vall_e/config.py b/vall_e/config.py
index ab57e7f..29d9288 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -338,7 +338,7 @@ class DeepSpeed:
 	use_compression_training: bool = False
 	compression_bits: int = 8
 	inferencing: bool = False
-	amp: bool = True
+	amp: bool = False
 
 	@cached_property
 	def ds_cfg(self):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 2cd0d04..52ca193 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -481,6 +481,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -500,6 +501,7 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
@@ -517,6 +519,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -536,7 +539,15 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
+					gradient_checkpointing=self.activation_checkpointing,
 				))
+
+			print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
+			if self.activation_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable()
+
+			if training:
+				self.model.training = True
 		elif self.arch_type == "retnet":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
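
For reference, a minimal standalone sketch (not part of the patch) of the Hugging Face Transformers gradient-checkpointing pattern this diff wires up: build the model, call gradient_checkpointing_enable(), and make sure the module is in training mode. The config values below are placeholders, not the values vall_e uses.

# Minimal sketch, not part of the patch: enabling activation (gradient)
# checkpointing on a Hugging Face LlamaModel. All config values are placeholders.
from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(
	vocab_size=1024,          # placeholder
	hidden_size=1024,         # placeholder
	num_hidden_layers=12,     # placeholder
	num_attention_heads=16,   # placeholder
))

model.gradient_checkpointing_enable()  # recompute activations in backward to save memory
model.train()                          # checkpointing only takes effect in training mode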