ugh
This commit is contained in:
parent
42fafbaaca
commit
953d3eb030
|
@ -1671,9 +1671,9 @@ class Base(nn.Module):
|
|||
|
||||
# mix if not nan
|
||||
if not torch.isnan(soft_loss).any():
|
||||
loss['kl'] = soft_loss * A
|
||||
for k in loss.keys():
|
||||
loss[k] *= (1.0 - A)
|
||||
loss['kl'] = soft_loss * A
|
||||
|
||||
# include any additional losses (for example: MoE router)
|
||||
if output.loss is not None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user