From 917eeb40d212d81e209eae81892d2aa2a108139f Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 12 May 2024 08:22:39 -0500
Subject: [PATCH] ughhh

---
 vall_e/models/base.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 54c0070..ae79f61 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -562,8 +562,8 @@ class Base(nn.Module):
 				use_reentrant=False
 			))
 
-			if training:
-				self.model.training = True
+			#if training:
+			#	self.model.training = True
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
 				self.model = LlamaModel(LlamaConfig(
@@ -608,8 +608,8 @@ class Base(nn.Module):
 				use_reentrant=False
 			))
 
-			if training:
-				self.model.training = True
+			#if training:
+			#	self.model.training = True
 		elif self.arch_type == "retnet":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
@@ -663,6 +663,11 @@ class Base(nn.Module):
 			)
 
 			self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))
+
+			if self.activation_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+					use_reentrant=False
+				))
 		elif self.arch_type == "bitnet":
 			self.model = BitNetTransformer(
 				num_tokens=n_resp_tokens,
@@ -675,11 +680,8 @@ class Base(nn.Module):
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
-		# Disabling for now, it might be broken
-		"""
 		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
 			self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
-		"""
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
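
Note: the substantive change in this patch is enabling non-reentrant activation
checkpointing on the HF-wrapped RetNet (plus re-enabling the attention backend
swap and commenting out the manual `self.model.training` toggle). A minimal
sketch of the same checkpointing call, assuming any Hugging Face
PreTrainedModel subclass; LlamaModel stands in here and the config sizes are
illustrative, not taken from the patch:

	from transformers import LlamaConfig, LlamaModel

	# Hypothetical small config; values are illustrative only.
	model = LlamaModel(LlamaConfig(
		vocab_size=1024,
		hidden_size=256,
		num_hidden_layers=4,
		num_attention_heads=4,
	))

	# Enable non-reentrant activation checkpointing, guarding against
	# re-enabling it if it is already on -- the same guard the patch adds
	# for the RetNetDecoder_HF branch.
	if not model.gradient_checkpointing:
		model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
			use_reentrant=False
		))

The `use_reentrant=False` variant avoids the limitations of PyTorch's
reentrant checkpointing (e.g. it tolerates inputs that do not require grad),
which is why it is passed through `gradient_checkpointing_kwargs` here.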