From 88ec0512f7f9572261c0a4ff140fef1ccc382452 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 17 May 2022 17:12:20 -0600
Subject: [PATCH] Scale losses

---
Review notes (ignored by git-am/git-apply):
  * The loss-count synchronisation now uses distributed.all_reduce(), which
    sums in place on every rank. The previous distributed.reduce(num_losses)
    call omitted the destination rank and used the call's return value,
    which is None for a synchronous collective.
  * Hunk lengths/offsets and the diffstat were updated for the added lines;
    the post-image blob hash on the "index" line was not recomputed.
  * Leading whitespace of context lines was reconstructed -- verify against
    the target file (or apply with --ignore-whitespace).

 codes/models/audio/mel2vec.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py
index c1721a2b..83cf943a 100644
--- a/codes/models/audio/mel2vec.py
+++ b/codes/models/audio/mel2vec.py
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import distributed
 from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
@@ -642,8 +643,21 @@ class ContrastiveTrainingWrapper(nn.Module):
         num_codevectors = self.quantizer.num_codevectors
         diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
 
+        # Scale both summed losses by the mean number of masked positions per rank.
+        num_losses = mask_time_indices.sum()
+        if distributed.is_initialized():
+            # all_reduce sums in place on every rank; reduce() requires a dst rank and returns None.
+            distributed.all_reduce(num_losses)
+            num_losses = num_losses / distributed.get_world_size()
+        contrastive_loss = contrastive_loss / num_losses
+        diversity_loss = diversity_loss / num_losses
+
         return contrastive_loss, diversity_loss
 
+    def before_step(self, it):
+        # Unscale the grads by the total number of losses encountered.
+        pass
+
 
 @register_model
 def register_mel2vec_pretraining(opt_net, opt):
@@ -657,5 +671,5 @@ def register_mel2vec(opt_net, opt):
 
 if __name__ == '__main__':
     model = ContrastiveTrainingWrapper()
-    mel = torch.randn((2,256,400))
+    mel = torch.randn((2,256,401))
     print(model(mel))
\ No newline at end of file