now causal training should work again
This commit is contained in:
parent
85b9dd47c1
commit
61de653ad9
|
@ -941,24 +941,26 @@ class Base_V2(nn.Module):
|
|||
if name != task_outputs.get(task_type, name):
|
||||
continue
|
||||
|
||||
sequence = token
|
||||
if token.dim() == 1:
|
||||
loss_factor = self.loss_factor(name)
|
||||
if loss_factor == 0.0:
|
||||
continue
|
||||
|
||||
logit = logits[batch_index][start:end]
|
||||
if causal or self.predict_causally:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||
token = token[..., l:] # ...predicts token n + 1
|
||||
|
||||
|
||||
"""
|
||||
if self.logit_normalization:
|
||||
logit = logit_normalization( logit, self.logit_normalization )
|
||||
"""
|
||||
|
||||
loss_targets.append( token.long() )
|
||||
loss_logits.append( logit )
|
||||
if causal or self.predict_causally:
|
||||
l = self.causal_size
|
||||
loss_targets.append( token[l:].long() ) # shift the target so that token n...
|
||||
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
|
||||
else:
|
||||
loss_targets.append( token.long() )
|
||||
loss_logits.append( logit )
|
||||
loss_factors.append( loss_factor )
|
||||
loss_names.append( name )
|
||||
else:
|
||||
|
@ -969,18 +971,21 @@ class Base_V2(nn.Module):
|
|||
continue
|
||||
|
||||
logit = logits[batch_index][level][start:end]
|
||||
if causal or self.predict_causally:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||
token = token[..., l:] # ...predicts token n + 1
|
||||
|
||||
"""
|
||||
if self.logit_normalization:
|
||||
logit = logit_normalization( logit, self.logit_normalization )
|
||||
"""
|
||||
|
||||
loss_targets.append( token[:, level].long() )
|
||||
loss_logits.append( logit )
|
||||
if causal or self.predict_causally:
|
||||
l = self.causal_size
|
||||
loss_targets.append( token[l:, level].long() ) # shift the target so that token n...
|
||||
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
|
||||
else:
|
||||
loss_targets.append( token[:, level].long() )
|
||||
loss_logits.append( logit )
|
||||
|
||||
|
||||
loss_factors.append( level_loss_factors[level] )
|
||||
loss_names.append( name )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user