ughhh
parent 9910c75d5a
commit 917eeb40d2
@@ -562,8 +562,8 @@ class Base(nn.Module):
					use_reentrant=False
				))

			if training:
				self.model.training = True
			#if training:
			#	self.model.training = True
		elif self.arch_type == "llama":
			if n_experts <= 1:
				self.model = LlamaModel(LlamaConfig(
@@ -608,8 +608,8 @@ class Base(nn.Module):
					use_reentrant=False
				))

			if training:
				self.model.training = True
			#if training:
			#	self.model.training = True
		elif self.arch_type == "retnet":
			kwargs = dict(
				vocab_size=n_resp_tokens,
@@ -663,6 +663,11 @@ class Base(nn.Module):
			)

			self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))

			if self.activation_checkpointing and not self.model.gradient_checkpointing:
				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
					use_reentrant=False
				))
		elif self.arch_type == "bitnet":
			self.model = BitNetTransformer(
				num_tokens=n_resp_tokens,
@@ -675,11 +680,8 @@ class Base(nn.Module):
		else:
			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

		# Disabling for now, it might be broken
		"""
		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
			self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
		"""

		self.classifier = nn.Linear(d_model, n_resp_tokens)
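Note: each arch branch touched above follows the same pattern: build self.model, then opt into non-reentrant activation checkpointing and set the training flag. The sketch below is a minimal, hypothetical reconstruction of that pattern for the llama branch only, not the repo's actual code; the helper name build_llama_backbone and its parameters are illustrative, while LlamaModel, LlamaConfig, and gradient_checkpointing_enable(gradient_checkpointing_kwargs=...) mirror the calls visible in the diff (the kwargs form assumes a recent transformers, roughly 4.35+).

# Hypothetical sketch of the per-arch pattern shown in the diff above.
from transformers import LlamaConfig, LlamaModel

def build_llama_backbone(d_model: int, n_heads: int, n_layers: int,
                         activation_checkpointing: bool = True,
                         training: bool = True) -> LlamaModel:
    # Construct the backbone, as in the `n_experts <= 1` llama branch.
    model = LlamaModel(LlamaConfig(
        hidden_size=d_model,
        num_attention_heads=n_heads,
        num_hidden_layers=n_layers,
    ))

    # use_reentrant=False selects torch.utils.checkpoint's non-reentrant
    # implementation, matching the kwargs dict added in the diff.
    if activation_checkpointing and not model.gradient_checkpointing:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=dict(use_reentrant=False)
        )

    # The diff sets the attribute directly; model.train() is the usual idiom.
    if training:
        model.train()

    return model

Called as, e.g., build_llama_backbone(1024, 16, 12), this returns a decoder whose checkpointed layers recompute activations during the backward pass instead of storing them.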