This commit is contained in:
mrq 2024-12-06 22:35:30 -06:00
parent 42fafbaaca
commit 953d3eb030

View File

@ -1671,9 +1671,9 @@ class Base(nn.Module):
# mix if not nan
if not torch.isnan(soft_loss).any():
loss['kl'] = soft_loss * A
for k in loss.keys():
loss[k] *= (1.0 - A)
loss['kl'] = soft_loss * A
# include any additional losses (for example: MoE router)
if output.loss is not None: