diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index b654ba4c..e9594445 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -1,6 +1,7 @@ import copy import functools import math +import random from typing import Optional, Tuple import numpy as np @@ -569,7 +570,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): class ContrastiveTrainingWrapper(nn.Module): def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100, max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, - codebook_size=320, codebook_groups=2, + codebook_size=320, codebook_groups=2, freq_mask_percent=0, **kwargs): super().__init__() self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, @@ -580,6 +581,7 @@ class ContrastiveTrainingWrapper(nn.Module): self.max_gumbel_temperature = max_gumbel_temperature self.min_gumbel_temperature = min_gumbel_temperature self.gumbel_temperature_decay = gumbel_temperature_decay + self.freq_mask_percent = freq_mask_percent self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, num_codevector_groups=codebook_groups, num_codevectors_per_group=codebook_size) self.num_losses_record = [] @@ -631,6 +633,12 @@ class ContrastiveTrainingWrapper(nn.Module): def forward(self, mel): mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math. + # Frequency masking + freq_mask_width = int(random.random() * self.freq_mask_percent * mel.shape[1]) + if freq_mask_width >= 2: + freq_start = random.randint(0, mel.shape[1]-freq_mask_width) + mel[:, freq_start:freq_start+freq_mask_width] = 0 + features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult) mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length) sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device) @@ -698,6 +706,6 @@ def register_mel2vec(opt_net, opt): if __name__ == '__main__': - model = ContrastiveTrainingWrapper() + model = ContrastiveTrainingWrapper(freq_mask_percent=.5) mel = torch.randn((2,256,401)) print(model(mel)) \ No newline at end of file