forked from mrq/DL-Art-School
Fix incorrect projections
This commit is contained in:
parent
c37fc3b4ed
commit
a6397ce84a
|
@ -435,14 +435,14 @@ class Mel2Vec(nn.Module):
|
|||
|
||||
def forward(self, mel, mask_time_indices=None, return_projections=False):
|
||||
proj = self.input_blocks(mel).permute(0,2,1)
|
||||
proj, _ = self.projector(proj)
|
||||
proj, norm_proj = self.projector(proj)
|
||||
|
||||
# Mask projections
|
||||
h = self.apply_masking(proj, mask_time_indices)
|
||||
h = self.encoder(h)
|
||||
|
||||
if return_projections:
|
||||
return h, proj
|
||||
return h, norm_proj
|
||||
return h
|
||||
|
||||
|
||||
|
@ -538,7 +538,6 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
super().__init__()
|
||||
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
|
||||
mask_time_length=mask_time_length, **kwargs)
|
||||
self.dropout_features = nn.Dropout(dropout)
|
||||
self.num_negatives = num_negatives
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
|
@ -580,8 +579,6 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
if self.freeze_main_net:
|
||||
return {}
|
||||
groups = {
|
||||
'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()),
|
||||
'encoder': list(self.m2v.encoder.parameters()),
|
||||
|
@ -603,10 +600,8 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
transformer_features = self.project_hid(outputs)
|
||||
|
||||
# 2. quantize all (unmasked) extracted features and project to final vq dim
|
||||
extract_features = self.dropout_features(proj)
|
||||
|
||||
quantized_features, codevector_perplexity = self.quantizer(
|
||||
extract_features, mask_time_indices=mask_time_indices
|
||||
proj, mask_time_indices=mask_time_indices
|
||||
)
|
||||
quantized_features = self.project_q(quantized_features)
|
||||
batch_size, sequence_length, hidden_size = quantized_features.shape
|
||||
|
|
Loading…
Reference in New Issue
Block a user