This commit is contained in:
James Betker 2022-05-17 18:04:20 -06:00
parent 7213ad2b89
commit 6130391a85

View File

@ -647,8 +647,8 @@ class ContrastiveTrainingWrapper(nn.Module):
if distributed.is_initialized():
distributed.all_reduce(num_losses)
num_losses = num_losses / distributed.get_world_size()
contrastive_loss = contrastive_loss / num_losses
diversity_loss = diversity_loss / num_losses
contrastive_loss = contrastive_loss
diversity_loss = diversity_loss
self.num_losses_record = num_losses.detach()
return contrastive_loss, diversity_loss