This commit is contained in:
mrq 2025-03-19 13:31:50 -05:00
parent 81acd565b3
commit 85b9dd47c1

View File

@@ -950,7 +950,7 @@ class Base_V2(nn.Module):
if causal or self.predict_causally:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
-token = sequence[..., l:] # ...predicts token n + 1
+token = token[..., l:] # ...predicts token n + 1
"""
if self.logit_normalization:
@@ -972,7 +972,7 @@ class Base_V2(nn.Module):
if causal or self.predict_causally:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
-token = sequence[..., l:] # ...predicts token n + 1
+token = token[..., l:] # ...predicts token n + 1
"""
if self.logit_normalization: