fix div
This commit is contained in:
parent
7213ad2b89
commit
6130391a85
|
@ -647,8 +647,8 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
if distributed.is_initialized():
|
||||
distributed.all_reduce(num_losses)
|
||||
num_losses = num_losses / distributed.get_world_size()
|
||||
contrastive_loss = contrastive_loss / num_losses
|
||||
diversity_loss = diversity_loss / num_losses
|
||||
contrastive_loss = contrastive_loss
|
||||
diversity_loss = diversity_loss
|
||||
|
||||
self.num_losses_record = num_losses.detach()
|
||||
return contrastive_loss, diversity_loss
|
||||
|
|
Loading…
Reference in New Issue
Block a user