diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 83cf943a..ea0fff6b 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -645,7 +645,7 @@ class ContrastiveTrainingWrapper(nn.Module): num_losses = mask_time_indices.sum() if distributed.is_initialized(): - num_losses = distributed.reduce(num_losses) / distributed.get_world_size() + num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size() contrastive_loss = contrastive_loss / num_losses diversity_loss = diversity_loss / num_losses