now causal training should work again

mrq 2025-03-19 14:20:19 -05:00
parent 85b9dd47c1
commit 61de653ad9


@@ -941,24 +941,26 @@ class Base_V2(nn.Module):
 if name != task_outputs.get(task_type, name):
     continue
-sequence = token
 if token.dim() == 1:
     loss_factor = self.loss_factor(name)
     if loss_factor == 0.0:
         continue
     logit = logits[batch_index][start:end]
-    if causal or self.predict_causally:
-        l = self.causal_size
-        logit = logit[..., :-l, :] # shift the target so that token n...
-        token = token[..., l:] # ...predicts token n + 1
     """
     if self.logit_normalization:
         logit = logit_normalization( logit, self.logit_normalization )
     """
-    loss_targets.append( token.long() )
-    loss_logits.append( logit )
+    if causal or self.predict_causally:
+        l = self.causal_size
+        loss_targets.append( token[l:].long() ) # shift the target so that token n...
+        loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
+    else:
+        loss_targets.append( token.long() )
+        loss_logits.append( logit )
     loss_factors.append( loss_factor )
     loss_names.append( name )
 else:
@@ -969,18 +971,21 @@ class Base_V2(nn.Module):
         continue
     logit = logits[batch_index][level][start:end]
-    if causal or self.predict_causally:
-        l = self.causal_size
-        logit = logit[..., :-l, :] # shift the target so that token n...
-        token = token[..., l:] # ...predicts token n + 1
     """
     if self.logit_normalization:
         logit = logit_normalization( logit, self.logit_normalization )
     """
-    loss_targets.append( token[:, level].long() )
-    loss_logits.append( logit )
+    if causal or self.predict_causally:
+        l = self.causal_size
+        loss_targets.append( token[l:, level].long() ) # shift the target so that token n...
+        loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
+    else:
+        loss_targets.append( token[:, level].long() )
+        loss_logits.append( logit )
    loss_factors.append( level_loss_factors[level] )
    loss_names.append( name )
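
For reference, a minimal standalone sketch of the shift the new branches perform, assuming a single 1D token sequence and per-position logits of shape [T, vocab]; the helper name causal_loss and the toy shapes are illustrative and not part of the repo, and only the causal_size idea mirrors self.causal_size from the diff:

import torch
import torch.nn.functional as F

def causal_loss(logit: torch.Tensor, token: torch.Tensor, causal_size: int = 1) -> torch.Tensor:
    # Drop the last `causal_size` logits and the first `causal_size` targets so
    # the logit at position n is scored against the token at position n + causal_size.
    l = causal_size
    shifted_logit = logit[..., :-l, :]  # [T - l, vocab]
    shifted_token = token[..., l:]      # [T - l]
    return F.cross_entropy(shifted_logit, shifted_token.long())

# Toy usage: a 10-token sequence over a 32-entry vocabulary.
logit = torch.randn(10, 32)
token = torch.randint(0, 32, (10,))
loss = causal_loss(logit, token)

Slicing both tensors before they are appended keeps the stored targets and logits the same length, which is what the causal cross-entropy expects.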