mrq 2025-02-28 01:04:24 -06:00
parent e6d421a3aa
commit 180a4eac1b

@@ -376,7 +376,9 @@ class Base_V2(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
-		if self.arch_type in ["llama"]:
+		if self.arch_type in ["none"]:
+			self.model = None
+		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
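
This hunk adds a "none" architecture that skips building a backbone entirely. A minimal sketch of the resulting constructor dispatch, kept outside the real class; build_backbone is a hypothetical stand-in for the logic in Base_V2.__init__, and the LlamaConfig call is truncated here just as it is in the diff:

from transformers import LlamaConfig, LlamaModel

def build_backbone( arch_type: str, n_vocab: int, d_model: int ):
	# "none" short-circuits before any HF config is built, leaving the
	# backbone unset so later code can test `model is None`.
	if arch_type in ["none"]:
		return None
	elif arch_type in ["llama"]:
		config = LlamaConfig(
			vocab_size=n_vocab,
			hidden_size=d_model,
		)
		return LlamaModel( config )
	raise ValueError( f"unknown arch_type: {arch_type}" )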
@@ -423,8 +425,10 @@ class Base_V2(nn.Module):
 		attentions = None
 		hidden_states = None
 
+		if self.arch_type in ["none"] or self.model is None:
+			...
 		# HF transformer derived model
-		if self.arch_type in ["llama"]:
+		elif self.arch_type in ["llama"]:
 			kwargs = dict(
 				inputs_embeds=x,
 				attention_mask=m,
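
The forward pass mirrors that branch: the "none" case is a deliberate no-op (the bare `...`), so the input embeddings flow through unchanged. A hedged sketch of that control flow, assuming `x` is the inputs_embeds tensor and `m` the attention mask as in the diff; forward_backbone and the llama branch's output unpacking are assumptions, not code from the repo:

def forward_backbone( model, arch_type, x, m ):
	# No backbone configured: pass the embeddings through untouched.
	if arch_type in ["none"] or model is None:
		return x
	# HF transformer derived model: delegate to the wrapped module.
	elif arch_type in ["llama"]:
		output = model(
			inputs_embeds=x,
			attention_mask=m,
		)
		return output.last_hidden_state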
@@ -988,7 +992,7 @@ class Base_V2(nn.Module):
 				nll, metrics = _calc_loss( logit, sequence, causal, level )
 				nlls.append( nll )
-				accs.append( accs )
+				accs.append( acc )
 			nll = sum(nlls)
 			accs = mean(accs)
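
The last hunk is a one-character bug fix: the old line appended the accumulator list accs to itself instead of the per-level accuracy acc, so the list ended up containing itself and any numeric mean over it would fail. A small demonstration, assuming a scalar acc and a statistics.mean-style mean() (the repo's helper may differ):

from statistics import mean

accs = [0.9]
acc = 0.8

buggy = list( accs )
buggy.append( buggy )	# old behavior: buggy == [0.9, [...]] (self-referential)
# mean( buggy )		# raises TypeError: a list is not a number

accs.append( acc )	# fixed behavior: accs == [0.9, 0.8]
print( mean( accs ) )	# ≈ 0.85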