From ee364f4eebbe1eced8dc8d0240a0694aca7fa626 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 17 May 2022 18:09:23 -0600 Subject: [PATCH] just take the mean... --- codes/models/audio/mel2vec.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index c8ff6655..6e768cd3 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -638,27 +638,29 @@ class ContrastiveTrainingWrapper(nn.Module): logits = logits.transpose(0, 2).reshape(-1, logits.size(0)) target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten() - contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum") + contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="mean") # 7. compute diversity loss: \mathbf{L}_d num_codevectors = self.quantizer.num_codevectors - diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum() + diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors + """ num_losses = mask_time_indices.sum() if distributed.is_initialized(): distributed.all_reduce(num_losses) num_losses = num_losses / distributed.get_world_size() - contrastive_loss = contrastive_loss - diversity_loss = diversity_loss - self.num_losses_record = num_losses.detach() + """ + return contrastive_loss, diversity_loss + """ def after_backward(self, it): if self.num_losses_record > 0: # Unscale the grads by the total number of losses encountered. for p in self.parameters(): if p.grad is not None: p.grad.data.div_(self.num_losses_record) + """ @register_model