ugh
This commit is contained in:
parent
81acd565b3
commit
85b9dd47c1
|
@ -950,7 +950,7 @@ class Base_V2(nn.Module):
|
|||
if causal or self.predict_causally:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||
token = sequence[..., l:] # ...predicts token n + 1
|
||||
token = token[..., l:] # ...predicts token n + 1
|
||||
|
||||
"""
|
||||
if self.logit_normalization:
|
||||
|
@ -972,7 +972,7 @@ class Base_V2(nn.Module):
|
|||
if causal or self.predict_causally:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||
token = sequence[..., l:] # ...predicts token n + 1
|
||||
token = token[..., l:] # ...predicts token n + 1
|
||||
|
||||
"""
|
||||
if self.logit_normalization:
|
||||
|
|
Loading…
Reference in New Issue
Block a user