From a6397ce84a785041a672c2177edc8e9d8a09973b Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 17 May 2022 16:53:52 -0600 Subject: [PATCH] Fix incorrect projections --- codes/models/audio/mel2vec.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index c8c4c077..c1721a2b 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -435,14 +435,14 @@ class Mel2Vec(nn.Module): def forward(self, mel, mask_time_indices=None, return_projections=False): proj = self.input_blocks(mel).permute(0,2,1) - proj, _ = self.projector(proj) + proj, norm_proj = self.projector(proj) # Mask projections h = self.apply_masking(proj, mask_time_indices) h = self.encoder(h) if return_projections: - return h, proj + return h, norm_proj return h @@ -538,7 +538,6 @@ class ContrastiveTrainingWrapper(nn.Module): super().__init__() self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, mask_time_length=mask_time_length, **kwargs) - self.dropout_features = nn.Dropout(dropout) self.num_negatives = num_negatives self.mask_time_prob = mask_time_prob self.mask_time_length = mask_time_length @@ -580,8 +579,6 @@ class ContrastiveTrainingWrapper(nn.Module): ) def get_grad_norm_parameter_groups(self): - if self.freeze_main_net: - return {} groups = { 'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()), 'encoder': list(self.m2v.encoder.parameters()), @@ -603,10 +600,8 @@ class ContrastiveTrainingWrapper(nn.Module): transformer_features = self.project_hid(outputs) # 2. quantize all (unmasked) extracted features and project to final vq dim - extract_features = self.dropout_features(proj) - quantized_features, codevector_perplexity = self.quantizer( - extract_features, mask_time_indices=mask_time_indices + proj, mask_time_indices=mask_time_indices ) quantized_features = self.project_q(quantized_features) batch_size, sequence_length, hidden_size = quantized_features.shape