ugh
This commit is contained in:
parent
81acd565b3
commit
85b9dd47c1
|
@ -950,7 +950,7 @@ class Base_V2(nn.Module):
|
||||||
if causal or self.predict_causally:
|
if causal or self.predict_causally:
|
||||||
l = self.causal_size
|
l = self.causal_size
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||||
token = sequence[..., l:] # ...predicts token n + 1
|
token = token[..., l:] # ...predicts token n + 1
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self.logit_normalization:
|
if self.logit_normalization:
|
||||||
|
@ -972,7 +972,7 @@ class Base_V2(nn.Module):
|
||||||
if causal or self.predict_causally:
|
if causal or self.predict_causally:
|
||||||
l = self.causal_size
|
l = self.causal_size
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||||
token = sequence[..., l:] # ...predicts token n + 1
|
token = token[..., l:] # ...predicts token n + 1
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self.logit_normalization:
|
if self.logit_normalization:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user