commit 180a4eac1b (parent e6d421a3aa)
mrq, 2025-02-28 01:04:24 -06:00

@@ -376,7 +376,9 @@ class Base_V2(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128

-		if self.arch_type in ["llama"]:
+		if self.arch_type in ["none"]:
+			self.model = None
+		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
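
In effect the constructor gains a no-backbone escape hatch: arch_type "none" sets self.model = None instead of building a LlamaConfig-backed model. A minimal sketch of the same dispatch, where the helper name build_model and the default values are assumptions, not names from the repository:

	# Sketch only; build_model and the defaults are hypothetical.
	from transformers import LlamaConfig

	def build_model(arch_type: str, n_vocab: int = 256, d_model: int = 1024):
		if arch_type in ["none"]:
			return None  # no backbone; callers must guard for model is None
		elif arch_type in ["llama"]:
			# the real code goes on to construct a model from this config
			return LlamaConfig(vocab_size=n_vocab, hidden_size=d_model)
		raise ValueError(f"unknown arch_type: {arch_type}")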
@@ -423,8 +425,10 @@ class Base_V2(nn.Module):
 		attentions = None
 		hidden_states = None

+		if self.arch_type in ["none"] or self.model is None:
+			...
 		# HF transformer derived model
-		if self.arch_type in ["llama"]:
+		elif self.arch_type in ["llama"]:
 			kwargs = dict(
 				inputs_embeds=x,
 				attention_mask=m,
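
The new forward branch is deliberately a stub: `...` is Python's Ellipsis expression, a no-op, so with arch_type "none" the input embeddings bypass the transformer entirely. A hypothetical condensation of the resulting control flow, assuming x holds the input embeddings and that the llama path uses the HF output's last_hidden_state (the method name run_backbone is not from the repository):

	# Sketch of the dispatch after this hunk; run_backbone is hypothetical.
	def run_backbone(self, x, m):
		if self.arch_type in ["none"] or self.model is None:
			return x  # no-op backbone: embeddings pass through unchanged
		elif self.arch_type in ["llama"]:
			out = self.model(inputs_embeds=x, attention_mask=m)
			return out.last_hidden_state
		raise ValueError(f"unknown arch_type: {self.arch_type}")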
@@ -988,7 +992,7 @@ class Base_V2(nn.Module):
 				nll, metrics = _calc_loss( logit, sequence, causal, level )
 				nlls.append( nll )
-				accs.append( accs )
+				accs.append( acc )
 			nll = sum(nlls)
 			accs = mean(accs)
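
This last hunk fixes a genuine bug: `accs.append( accs )` appended the list to itself, so the later `accs = mean(accs)` would hit a nested list and raise a TypeError; appending the scalar `acc` makes the average well-defined. A toy reproduction, assuming `mean` is statistics.mean:

	from statistics import mean

	nlls, accs = [], []
	for nll, acc in [(2.0, 0.5), (1.0, 0.75)]:  # toy per-level values
		nlls.append(nll)
		accs.append(acc)  # fixed: append the scalar acc, not the list accs
	nll = sum(nlls)       # 3.0
	acc = mean(accs)      # 0.625; the old bug made mean() see a list-in-list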