diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 285cf299..c8ff6655 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -647,8 +647,8 @@ class ContrastiveTrainingWrapper(nn.Module): if distributed.is_initialized(): distributed.all_reduce(num_losses) num_losses = num_losses / distributed.get_world_size() - contrastive_loss = contrastive_loss / num_losses - diversity_loss = diversity_loss / num_losses + contrastive_loss = contrastive_loss + diversity_loss = diversity_loss self.num_losses_record = num_losses.detach() return contrastive_loss, diversity_loss