just take the mean...

This commit is contained in:
James Betker 2022-05-17 18:09:23 -06:00
parent 6130391a85
commit ee364f4eeb

View File

@ -638,27 +638,29 @@ class ContrastiveTrainingWrapper(nn.Module):
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="mean")
# 7. compute diversity loss: \mathbf{L}_d
num_codevectors = self.quantizer.num_codevectors
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
"""
num_losses = mask_time_indices.sum()
if distributed.is_initialized():
distributed.all_reduce(num_losses)
num_losses = num_losses / distributed.get_world_size()
contrastive_loss = contrastive_loss
diversity_loss = diversity_loss
self.num_losses_record = num_losses.detach()
"""
return contrastive_loss, diversity_loss
"""
def after_backward(self, it):
if self.num_losses_record > 0:
# Unscale the grads by the total number of losses encountered.
for p in self.parameters():
if p.grad is not None:
p.grad.data.div_(self.num_losses_record)
"""
@register_model