forked from mrq/DL-Art-School
Scale losses
commit 88ec0512f7
parent a6397ce84a
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import distributed
 from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
 from transformers.deepspeed import is_deepspeed_zero3_enabled
@@ -642,8 +643,18 @@ class ContrastiveTrainingWrapper(nn.Module):
         num_codevectors = self.quantizer.num_codevectors
         diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
 
+        num_losses = mask_time_indices.sum()
+        if distributed.is_initialized():
+            num_losses = distributed.reduce(num_losses) / distributed.get_world_size()
+        contrastive_loss = contrastive_loss / num_losses
+        diversity_loss = diversity_loss / num_losses
+
         return contrastive_loss, diversity_loss
 
+    def before_step(self, it):
+        # Unscale the grads by the total number of losses encountered.
+        pass
+
 
 @register_model
 def register_mel2vec_pretraining(opt_net, opt):
@@ -657,5 +668,5 @@ def register_mel2vec(opt_net, opt):
 
 if __name__ == '__main__':
     model = ContrastiveTrainingWrapper()
-    mel = torch.randn((2,256,400))
+    mel = torch.randn((2,256,401))
    print(model(mel))
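For context, below is a minimal standalone sketch (not part of the commit) of the scaling the second hunk introduces: normalize both losses by the number of masked positions and average that count across distributed workers. The helper name scale_losses is hypothetical, and it assumes mask_time_indices is a torch tensor. Note that torch.distributed.reduce and all_reduce modify the tensor in place and return None, so the sketch calls all_reduce and then divides rather than assigning the return value as the diff does.

# Hypothetical helper illustrating the loss scaling; not part of the diff.
import torch
import torch.distributed as distributed

def scale_losses(contrastive_loss, diversity_loss, mask_time_indices):
    # Number of masked time steps that contributed a loss term in this batch.
    num_losses = mask_time_indices.sum().float()
    if distributed.is_initialized():
        # all_reduce sums the count in place across ranks; dividing by the
        # world size gives the average, so every rank normalizes by the same value.
        distributed.all_reduce(num_losses)
        num_losses = num_losses / distributed.get_world_size()
    return contrastive_loss / num_losses, diversity_loss / num_losses

Dividing by the per-batch count rather than a fixed constant keeps the gradient magnitude comparable across batches with different numbers of masked positions, which is presumably what the before_step comment about unscaling grads refers to.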