ughh
This commit is contained in:
parent
e6d421a3aa
commit
180a4eac1b
|
@ -376,7 +376,9 @@ class Base_V2(nn.Module):
|
|||
elif attention_backend == "fused_attn":
|
||||
self.l_padding = 128
|
||||
|
||||
- if self.arch_type in ["llama"]:
|
||||
+ if self.arch_type in ["none"]:
|
||||
+ self.model = None
|
||||
+ elif self.arch_type in ["llama"]:
|
||||
self.model_config = LlamaConfig(
|
||||
vocab_size=n_vocab,
|
||||
hidden_size=d_model,
|
||||
|
@ -423,8 +425,10 @@ class Base_V2(nn.Module):
|
|||
attentions = None
|
||||
hidden_states = None
|
||||
|
||||
+ if self.arch_type in ["none"] or self.model is None:
|
||||
+ ...
|
||||
# HF transformer derived model
|
||||
- if self.arch_type in ["llama"]:
|
||||
+ elif self.arch_type in ["llama"]:
|
||||
kwargs = dict(
|
||||
inputs_embeds=x,
|
||||
attention_mask=m,
|
||||
|
@ -988,7 +992,7 @@ class Base_V2(nn.Module):
|
|||
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
||||
|
||||
nlls.append( nll )
|
||||
- accs.append( accs )
|
||||
+ accs.append( acc )
|
||||
|
||||
nll = sum(nlls)
|
||||
accs = mean(accs)
|
||||
|
|
Loading…
Reference in New Issue
Block a user