diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 7ae1b9a..cded4ec 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1671,9 +1671,9 @@ class Base(nn.Module): # mix if not nan if not torch.isnan(soft_loss).any(): - loss['kl'] = soft_loss * A for k in loss.keys(): loss[k] *= (1.0 - A) + loss['kl'] = soft_loss * A # include any additional losses (for example: MoE router) if output.loss is not None: