Scale losses

James Betker 2022-05-17 17:12:20 -06:00
parent a6397ce84a
commit 88ec0512f7


@@ -7,6 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import distributed
 from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
 from transformers.deepspeed import is_deepspeed_zero3_enabled
@@ -642,8 +643,18 @@ class ContrastiveTrainingWrapper(nn.Module):
         num_codevectors = self.quantizer.num_codevectors
         diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+        num_losses = mask_time_indices.sum()
+        if distributed.is_initialized():
+            num_losses = distributed.reduce(num_losses) / distributed.get_world_size()
+        contrastive_loss = contrastive_loss / num_losses
+        diversity_loss = diversity_loss / num_losses
         return contrastive_loss, diversity_loss
+
+    def before_step(self, it):
+        # Unscale the grads by the total number of losses encountered.
+        pass

 @register_model
 def register_mel2vec_pretraining(opt_net, opt):
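The hunk above turns the summed contrastive and diversity losses into per-mask averages: num_losses counts the masked time steps, and when torch.distributed is initialized that count is meant to be averaged across data-parallel ranks so every worker divides by the same value, keeping gradient magnitudes independent of batch size and masking density. Note that torch.distributed.reduce mutates its tensor in place and returns None in synchronous mode, so the usual way to obtain a cross-rank mean of a scalar is all_reduce followed by division by the world size. A minimal sketch of that pattern against plain torch.distributed (the helper names average_across_ranks and scale_losses are illustrative, not part of this repository):

import torch
import torch.distributed as dist

def average_across_ranks(value: torch.Tensor) -> torch.Tensor:
    # all_reduce sums `value` in place on every rank; dividing by the world
    # size turns that sum into a mean. On a single-process run the value is
    # returned unchanged.
    if dist.is_available() and dist.is_initialized():
        value = value.clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value = value / dist.get_world_size()
    return value

def scale_losses(contrastive_loss, diversity_loss, mask_time_indices):
    # Both losses are sums over masked time steps; dividing by the
    # rank-averaged number of masked positions yields per-mask means.
    num_losses = average_across_ranks(mask_time_indices.sum().float())
    return contrastive_loss / num_losses, diversity_loss / num_losses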
@@ -657,5 +668,5 @@ def register_mel2vec(opt_net, opt):
 if __name__ == '__main__':
     model = ContrastiveTrainingWrapper()
-    mel = torch.randn((2,256,400))
+    mel = torch.randn((2,256,401))
     print(model(mel))
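The smoke test above only prints the raw (contrastive_loss, diversity_loss) pair returned by the wrapper. In wav2vec 2.0-style pretraining the two terms are typically combined with a small weight on the diversity loss before backpropagation; a hedged sketch of that step, assuming the forward pass returns the two losses as in the hunk above and using the conventional 0.1 weight (an assumption, not a value taken from this commit):

# Hypothetical glue around the __main__ smoke test in this file, where
# torch and ContrastiveTrainingWrapper are already in scope.
model = ContrastiveTrainingWrapper()
mel = torch.randn((2, 256, 401))
contrastive_loss, diversity_loss = model(mel)
diversity_weight = 0.1  # assumed; mirrors the common wav2vec 2.0 setting
total_loss = contrastive_loss + diversity_weight * diversity_loss
total_loss.backward()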