From 10f4a742bd6b27bb3948c7344706164d0c1613b7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 23 May 2022 08:16:04 -0600 Subject: [PATCH] reintroduce attention masks --- codes/models/audio/mel2vec.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index e9594445..3b11bac2 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -570,7 +570,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): class ContrastiveTrainingWrapper(nn.Module): def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100, max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, - codebook_size=320, codebook_groups=2, freq_mask_percent=0, + codebook_size=320, codebook_groups=2, freq_mask_percent=0, inp_length_multiplier=256, **kwargs): super().__init__() self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, @@ -584,6 +584,7 @@ class ContrastiveTrainingWrapper(nn.Module): self.freq_mask_percent = freq_mask_percent self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, num_codevector_groups=codebook_groups, num_codevectors_per_group=codebook_size) self.num_losses_record = [] + self.inp_length_factor = inp_length_multiplier # make sure that project_hid & project_q are initialized like normal linear layers self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim) @@ -630,7 +631,7 @@ class ContrastiveTrainingWrapper(nn.Module): codes = self.quantizer.get_codes(proj) return codes - def forward(self, mel): + def forward(self, mel, inp_lengths=None): mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math. # Frequency masking @@ -639,8 +640,15 @@ class ContrastiveTrainingWrapper(nn.Module): freq_start = random.randint(0, mel.shape[1]-freq_mask_width) mel[:, freq_start:freq_start+freq_mask_width] = 0 + # Build input masks from inp_lengths if possible. + attention_mask = torch.ones_like(mel) + if inp_lengths is not None: + inp_lengths = inp_lengths // self.inp_length_factor + for i, l in enumerate(inp_lengths): + attention_mask[i, l:] = 0 + features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult) - mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length) + mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length, attention_mask=attention_mask) sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device) mask_time_indices = torch.tensor(mask_time_indices, device=mel.device)