mrq 2025-02-27 19:05:32 -06:00
parent eff180248c
commit 01e96bafc9
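
For context: the hunks below swap LlamaModel_Adapted for the stock Hugging Face LlamaModel and drop the ml.replace_attention call. A minimal sketch of that stock construction follows; it is not taken from the repo (the dimension values are placeholders, and the attn_implementation kwarg is an assumption about how an attention backend could be selected once replace_attention is gone, not something this commit shows):

# sketch only: stock HF LlamaModel plus non-reentrant gradient checkpointing
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    vocab_size=256,                # stands in for n_vocab in the diff
    hidden_size=1024,              # stands in for d_model in the diff
    max_position_embeddings=4096,  # stands in for max_position_embeddings
    attn_implementation="sdpa",    # assumption, not from this commit
)
model = LlamaModel(config)

# mirrors the checkpointing call kept in the second hunk
if not model.gradient_checkpointing:
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs=dict(use_reentrant=False)
    )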

@@ -321,7 +321,7 @@ class Base_V2(nn.Module):
 		self.l_padding = 128
 		if self.arch_type in ["llama"]:
-			self.model = LlamaModel_Adapted(LlamaConfig(
+			self.model = LlamaModel(LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
@@ -337,8 +337,6 @@ class Base_V2(nn.Module):
 				#gradient_checkpointing=self.gradient_checkpointing,
 			))
-			self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 					use_reentrant=False