darn mpi
This commit is contained in:
parent
88ec0512f7
commit
7c82e18c6c
|
@ -645,7 +645,7 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
|
||||
num_losses = mask_time_indices.sum()
|
||||
if distributed.is_initialized():
|
||||
num_losses = distributed.reduce(num_losses) / distributed.get_world_size()
|
||||
num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
|
||||
contrastive_loss = contrastive_loss / num_losses
|
||||
diversity_loss = diversity_loss / num_losses
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user