now causal training should work again
This commit is contained in:
parent
85b9dd47c1
commit
61de653ad9
|
@ -941,24 +941,26 @@ class Base_V2(nn.Module):
|
||||||
if name != task_outputs.get(task_type, name):
|
if name != task_outputs.get(task_type, name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
sequence = token
|
||||||
if token.dim() == 1:
|
if token.dim() == 1:
|
||||||
loss_factor = self.loss_factor(name)
|
loss_factor = self.loss_factor(name)
|
||||||
if loss_factor == 0.0:
|
if loss_factor == 0.0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logit = logits[batch_index][start:end]
|
logit = logits[batch_index][start:end]
|
||||||
if causal or self.predict_causally:
|
|
||||||
l = self.causal_size
|
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
|
||||||
token = token[..., l:] # ...predicts token n + 1
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self.logit_normalization:
|
if self.logit_normalization:
|
||||||
logit = logit_normalization( logit, self.logit_normalization )
|
logit = logit_normalization( logit, self.logit_normalization )
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss_targets.append( token.long() )
|
if causal or self.predict_causally:
|
||||||
loss_logits.append( logit )
|
l = self.causal_size
|
||||||
|
loss_targets.append( token[l:].long() ) # shift the target so that token n...
|
||||||
|
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
|
||||||
|
else:
|
||||||
|
loss_targets.append( token.long() )
|
||||||
|
loss_logits.append( logit )
|
||||||
loss_factors.append( loss_factor )
|
loss_factors.append( loss_factor )
|
||||||
loss_names.append( name )
|
loss_names.append( name )
|
||||||
else:
|
else:
|
||||||
|
@ -969,18 +971,21 @@ class Base_V2(nn.Module):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logit = logits[batch_index][level][start:end]
|
logit = logits[batch_index][level][start:end]
|
||||||
if causal or self.predict_causally:
|
|
||||||
l = self.causal_size
|
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
|
||||||
token = token[..., l:] # ...predicts token n + 1
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self.logit_normalization:
|
if self.logit_normalization:
|
||||||
logit = logit_normalization( logit, self.logit_normalization )
|
logit = logit_normalization( logit, self.logit_normalization )
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss_targets.append( token[:, level].long() )
|
if causal or self.predict_causally:
|
||||||
loss_logits.append( logit )
|
l = self.causal_size
|
||||||
|
loss_targets.append( token[l:, level].long() ) # shift the target so that token n...
|
||||||
|
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
|
||||||
|
else:
|
||||||
|
loss_targets.append( token[:, level].long() )
|
||||||
|
loss_logits.append( logit )
|
||||||
|
|
||||||
|
|
||||||
loss_factors.append( level_loss_factors[level] )
|
loss_factors.append( level_loss_factors[level] )
|
||||||
loss_names.append( name )
|
loss_names.append( name )
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user