ugh
This commit is contained in:
parent
42fafbaaca
commit
953d3eb030
|
@ -1671,9 +1671,9 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# mix if not nan
|
# mix if not nan
|
||||||
if not torch.isnan(soft_loss).any():
|
if not torch.isnan(soft_loss).any():
|
||||||
loss['kl'] = soft_loss * A
|
|
||||||
for k in loss.keys():
|
for k in loss.keys():
|
||||||
loss[k] *= (1.0 - A)
|
loss[k] *= (1.0 - A)
|
||||||
|
loss['kl'] = soft_loss * A
|
||||||
|
|
||||||
# include any additional losses (for example: MoE router)
|
# include any additional losses (for example: MoE router)
|
||||||
if output.loss is not None:
|
if output.loss is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user