forked from mrq/DL-Art-School
Scale losses
This commit is contained in:
parent
a6397ce84a
commit
88ec0512f7
|
@ -7,6 +7,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import distributed
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
|
@ -642,8 +643,18 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
num_codevectors = self.quantizer.num_codevectors
|
||||
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
|
||||
|
||||
num_losses = mask_time_indices.sum()
|
||||
if distributed.is_initialized():
    # torch.distributed.reduce() needs a destination rank, only delivers the
    # result to that rank, and returns None for a synchronous call, so its
    # return value cannot be divided. all_reduce() sums the count in place on
    # every rank instead; dividing by the world size then gives each rank
    # the same mean number of losses per process.
    distributed.all_reduce(num_losses)
    num_losses = num_losses / distributed.get_world_size()
|
||||
contrastive_loss = contrastive_loss / num_losses
|
||||
diversity_loss = diversity_loss / num_losses
|
||||
|
||||
return contrastive_loss, diversity_loss
|
||||
|
||||
def before_step(self, it):
    """Placeholder that currently does nothing.

    The intended behaviour is to unscale the grads by the total number of
    losses encountered; that logic is not implemented yet.

    `it` is accepted but unused (presumably the current training
    iteration -- confirm against the caller).
    """
    return None
|
||||
|
||||
|
||||
@register_model
|
||||
def register_mel2vec_pretraining(opt_net, opt):
|
||||
|
@ -657,5 +668,5 @@ def register_mel2vec(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
    # Smoke test: build the wrapper, run it once on a random input and print
    # whatever the model returns.
    model = ContrastiveTrainingWrapper()
    # Only the 401-wide random input is kept; the 400-wide tensor that used to
    # be assigned first was overwritten before it was ever used.
    mel = torch.randn((2,256,401))
    print(model(mel))
|
Loading…
Reference in New Issue
Block a user