From 01e96bafc957e7143774e4514e7d0e280106393c Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 27 Feb 2025 19:05:32 -0600 Subject: [PATCH] ugh --- vall_e/models/base_v2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index a8ee428..f185dea 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -321,7 +321,7 @@ class Base_V2(nn.Module): self.l_padding = 128 if self.arch_type in ["llama"]: - self.model = LlamaModel_Adapted(LlamaConfig( + self.model = LlamaModel(LlamaConfig( vocab_size=n_vocab, hidden_size=d_model, max_position_embeddings=max_position_embeddings, @@ -337,8 +337,6 @@ class Base_V2(nn.Module): #gradient_checkpointing=self.gradient_checkpointing, )) - self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend ) - if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False