ugh

commit 01e96bafc9 (parent eff180248c)
@@ -321,7 +321,7 @@ class Base_V2(nn.Module):
 		self.l_padding = 128
 
 		if self.arch_type in ["llama"]:
-			self.model = LlamaModel(LlamaConfig(
+			self.model = LlamaModel_Adapted(LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
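The hunk above swaps the stock LlamaModel for LlamaModel_Adapted when building the llama backbone. A minimal sketch of what such an adapted subclass could look like, assuming it bakes LlamaAttention_Adapted into each decoder layer at construction time (which would make the post-hoc replace_attention pass removed below redundant); both _Adapted classes belong to this repo and their bodies here are hypothetical, only the transformers imports are real:

# Sketch only: the repo's actual definitions may differ.
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel

class LlamaAttention_Adapted(LlamaAttention):
	# hypothetical stand-in; the repo's version presumably overrides
	# forward() to select among attention backends
	pass

class LlamaModel_Adapted(LlamaModel):
	def __init__(self, config: LlamaConfig):
		super().__init__(config)
		# swap each decoder layer's self-attention in place, so callers
		# no longer need a separate ml.replace_attention() pass
		for idx, layer in enumerate(self.layers):
			layer.self_attn = LlamaAttention_Adapted(config, layer_idx=idx)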
@@ -337,8 +337,6 @@ class Base_V2(nn.Module):
 				#gradient_checkpointing=self.gradient_checkpointing,
 			))
-
-			self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
 
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 					use_reentrant=False
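This hunk drops the separate ml.replace_attention() pass but keeps the gradient-checkpointing setup. gradient_checkpointing_enable() with use_reentrant=False is the standard transformers call; it opts into PyTorch's non-reentrant checkpointing, which tolerates frozen parameters and unused inputs better than the reentrant default. A standalone usage sketch, with illustrative config sizes:

from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(
	hidden_size=256, num_hidden_layers=2, num_attention_heads=4, intermediate_size=512,
))
# mirror the diff's guard: only enable if not already checkpointing
if not model.gradient_checkpointing:
	model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
		use_reentrant=False,  # non-reentrant variant, as in the diff
	))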