m2v frequency masking

This commit is contained in:
James Betker 2022-05-23 07:04:12 -06:00
parent 4093e38717
commit 68c0afcbcc

View File

@ -1,6 +1,7 @@
import copy import copy
import functools import functools
import math import math
import random
from typing import Optional, Tuple from typing import Optional, Tuple
import numpy as np import numpy as np
@ -569,7 +570,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
class ContrastiveTrainingWrapper(nn.Module): class ContrastiveTrainingWrapper(nn.Module):
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100, def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100,
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
codebook_size=320, codebook_groups=2, codebook_size=320, codebook_groups=2, freq_mask_percent=0,
**kwargs): **kwargs):
super().__init__() super().__init__()
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
@ -580,6 +581,7 @@ class ContrastiveTrainingWrapper(nn.Module):
self.max_gumbel_temperature = max_gumbel_temperature self.max_gumbel_temperature = max_gumbel_temperature
self.min_gumbel_temperature = min_gumbel_temperature self.min_gumbel_temperature = min_gumbel_temperature
self.gumbel_temperature_decay = gumbel_temperature_decay self.gumbel_temperature_decay = gumbel_temperature_decay
self.freq_mask_percent = freq_mask_percent
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, num_codevector_groups=codebook_groups, num_codevectors_per_group=codebook_size) self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, num_codevector_groups=codebook_groups, num_codevectors_per_group=codebook_size)
self.num_losses_record = [] self.num_losses_record = []
@ -631,6 +633,12 @@ class ContrastiveTrainingWrapper(nn.Module):
def forward(self, mel): def forward(self, mel):
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math. mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
# Frequency masking
freq_mask_width = int(random.random() * self.freq_mask_percent * mel.shape[1])
if freq_mask_width >= 2:
freq_start = random.randint(0, mel.shape[1]-freq_mask_width)
mel[:, freq_start:freq_start+freq_mask_width] = 0
features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult) features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)
mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length) mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length)
sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device) sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)
@ -698,6 +706,6 @@ def register_mel2vec(opt_net, opt):
if __name__ == '__main__': if __name__ == '__main__':
model = ContrastiveTrainingWrapper() model = ContrastiveTrainingWrapper(freq_mask_percent=.5)
mel = torch.randn((2,256,401)) mel = torch.randn((2,256,401))
print(model(mel)) print(model(mel))