diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 7cb24b0..d9a3127 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -365,7 +365,7 @@ def example_usage():
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 4, # 32
+		'n_layers': 12, # 32
 		'n_experts': 1,
 
 		'l_padding': 8 if cfg.optimizations.fp8 else 0,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 52ca193..55a866e 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -481,7 +481,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -501,8 +501,14 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
+
+			if self.activation_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable()
+
+			if training:
+				self.model.training = True
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
 				self.model = LlamaModel(LlamaConfig(
@@ -519,7 +525,7 @@ class Base(nn.Module):
 					is_encoder_decoder=False,
 					is_decoder=True,
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -539,10 +545,9 @@ class Base(nn.Module):
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 					attn_implementation=attention,
-					gradient_checkpointing=self.activation_checkpointing,
+					#gradient_checkpointing=self.activation_checkpointing,
 				))
 
-			print("Checkpointing:", self.activation_checkpointing, self.model.gradient_checkpointing)
 
 			if self.activation_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable()
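
Note: newer transformers releases deprecate (and eventually reject) passing
`gradient_checkpointing` through the model config, which is why the diff
comments out the config kwarg and instead calls
`gradient_checkpointing_enable()` on the constructed model. A minimal sketch
of the resulting pattern, assuming a current transformers install; the config
values below are illustrative, not the repo's actual settings:

    from transformers import LlamaConfig, LlamaModel

    model = LlamaModel(LlamaConfig(
        vocab_size=1024,
        hidden_size=1024,
        num_hidden_layers=12,   # matches the bumped n_layers above
        num_attention_heads=16,
    ))

    # LlamaModel initializes self.gradient_checkpointing = False;
    # enable it after construction rather than via the config.
    if not model.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Checkpointing only takes effect in training mode. Module.train()
    # also flips the flag on all submodules, which the diff's direct
    # assignment (self.model.training = True) does not.
    model.train()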