This commit is contained in:
James Betker 2022-05-17 17:16:09 -06:00
parent 88ec0512f7
commit 7c82e18c6c

View File

@ -645,7 +645,7 @@ class ContrastiveTrainingWrapper(nn.Module):
num_losses = mask_time_indices.sum() num_losses = mask_time_indices.sum()
if distributed.is_initialized(): if distributed.is_initialized():
num_losses = distributed.reduce(num_losses) / distributed.get_world_size() num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
contrastive_loss = contrastive_loss / num_losses contrastive_loss = contrastive_loss / num_losses
diversity_loss = diversity_loss / num_losses diversity_loss = diversity_loss / num_losses