darn mpi
This commit is contained in:
parent
88ec0512f7
commit
7c82e18c6c
|
@ -645,7 +645,7 @@ class ContrastiveTrainingWrapper(nn.Module):
|
||||||
|
|
||||||
num_losses = mask_time_indices.sum()
|
num_losses = mask_time_indices.sum()
|
||||||
if distributed.is_initialized():
|
if distributed.is_initialized():
|
||||||
num_losses = distributed.reduce(num_losses) / distributed.get_world_size()
|
num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
|
||||||
contrastive_loss = contrastive_loss / num_losses
|
contrastive_loss = contrastive_loss / num_losses
|
||||||
diversity_loss = diversity_loss / num_losses
|
diversity_loss = diversity_loss / num_losses
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user