mrq 2024-05-12 08:22:39 -05:00
parent 9910c75d5a
commit 917eeb40d2


@@ -562,8 +562,8 @@ class Base(nn.Module):
					use_reentrant=False
				))
			if training:
				self.model.training = True
			#if training:
			#	self.model.training = True
		elif self.arch_type == "llama":
			if n_experts <= 1:
				self.model = LlamaModel(LlamaConfig(
@@ -608,8 +608,8 @@ class Base(nn.Module):
					use_reentrant=False
				))
			if training:
				self.model.training = True
			#if training:
			#	self.model.training = True
		elif self.arch_type == "retnet":
			kwargs = dict(
				vocab_size=n_resp_tokens,
@@ -663,6 +663,11 @@ class Base(nn.Module):
			)
			self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))
			if self.activation_checkpointing and not self.model.gradient_checkpointing:
				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
					use_reentrant=False
				))
		elif self.arch_type == "bitnet":
			self.model = BitNetTransformer(
				num_tokens=n_resp_tokens,
@@ -675,11 +680,8 @@
		else:
			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
		# Disabling for now, it might be broken
		"""
		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
			self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
		"""
		self.classifier = nn.Linear(d_model, n_resp_tokens)
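
For reference, a minimal standalone sketch of the pattern the diff converges on: gradient checkpointing is enabled per backend behind a guard, and the training flag is no longer assigned directly. This is not the repo's code; it assumes a recent transformers release, and the tiny config values are placeholders standing in for the sizes Base.__init__ computes.

from transformers import LlamaConfig, LlamaModel

# Placeholder sizes; the real model derives these from the config.
model = LlamaModel(LlamaConfig(
	vocab_size=256,
	hidden_size=64,
	num_hidden_layers=2,
	num_attention_heads=2,
	intermediate_size=128,
))

# Guarded, non-reentrant activation checkpointing, mirroring the block
# the commit adds for the retnet branch above.
if not model.gradient_checkpointing:
	model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
		use_reentrant=False
	))

# Toggle train/eval via train()/eval() rather than assigning model.training
# directly: the plain assignment only flips the flag on the outermost module,
# while train() propagates it to every submodule. (The direct assignment is
# the line the hunks above comment out.)
model.train()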