Fix incorrect projections

This commit is contained in:
James Betker 2022-05-17 16:53:52 -06:00
parent c37fc3b4ed
commit a6397ce84a

View File

@ -435,14 +435,14 @@ class Mel2Vec(nn.Module):
def forward(self, mel, mask_time_indices=None, return_projections=False):
proj = self.input_blocks(mel).permute(0,2,1)
proj, _ = self.projector(proj)
proj, norm_proj = self.projector(proj)
# Mask projections
h = self.apply_masking(proj, mask_time_indices)
h = self.encoder(h)
if return_projections:
return h, proj
return h, norm_proj
return h
@ -538,7 +538,6 @@ class ContrastiveTrainingWrapper(nn.Module):
super().__init__()
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
mask_time_length=mask_time_length, **kwargs)
self.dropout_features = nn.Dropout(dropout)
self.num_negatives = num_negatives
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
@ -580,8 +579,6 @@ class ContrastiveTrainingWrapper(nn.Module):
)
def get_grad_norm_parameter_groups(self):
if self.freeze_main_net:
return {}
groups = {
'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()),
'encoder': list(self.m2v.encoder.parameters()),
@ -603,10 +600,8 @@ class ContrastiveTrainingWrapper(nn.Module):
transformer_features = self.project_hid(outputs)
# 2. quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(proj)
quantized_features, codevector_perplexity = self.quantizer(
extract_features, mask_time_indices=mask_time_indices
proj, mask_time_indices=mask_time_indices
)
quantized_features = self.project_q(quantized_features)
batch_size, sequence_length, hidden_size = quantized_features.shape