ughh
This commit is contained in:
parent
e6d421a3aa
commit
180a4eac1b
|
@ -376,7 +376,9 @@ class Base_V2(nn.Module):
|
||||||
elif attention_backend == "fused_attn":
|
elif attention_backend == "fused_attn":
|
||||||
self.l_padding = 128
|
self.l_padding = 128
|
||||||
|
|
||||||
if self.arch_type in ["llama"]:
|
if self.arch_type in ["none"]:
|
||||||
|
self.model = None
|
||||||
|
elif self.arch_type in ["llama"]:
|
||||||
self.model_config = LlamaConfig(
|
self.model_config = LlamaConfig(
|
||||||
vocab_size=n_vocab,
|
vocab_size=n_vocab,
|
||||||
hidden_size=d_model,
|
hidden_size=d_model,
|
||||||
|
@ -423,8 +425,10 @@ class Base_V2(nn.Module):
|
||||||
attentions = None
|
attentions = None
|
||||||
hidden_states = None
|
hidden_states = None
|
||||||
|
|
||||||
|
if self.arch_type in ["none"] or self.model is None:
|
||||||
|
...
|
||||||
# HF transformer derived model
|
# HF transformer derived model
|
||||||
if self.arch_type in ["llama"]:
|
elif self.arch_type in ["llama"]:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
inputs_embeds=x,
|
inputs_embeds=x,
|
||||||
attention_mask=m,
|
attention_mask=m,
|
||||||
|
@ -988,7 +992,7 @@ class Base_V2(nn.Module):
|
||||||
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
||||||
|
|
||||||
nlls.append( nll )
|
nlls.append( nll )
|
||||||
accs.append( accs )
|
accs.append( acc )
|
||||||
|
|
||||||
nll = sum(nlls)
|
nll = sum(nlls)
|
||||||
accs = mean(accs)
|
accs = mean(accs)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user