forked from mrq/DL-Art-School
Fix incorrect projections
This commit is contained in:
parent c37fc3b4ed
commit a6397ce84a
@@ -435,14 +435,14 @@ class Mel2Vec(nn.Module):
     def forward(self, mel, mask_time_indices=None, return_projections=False):
         proj = self.input_blocks(mel).permute(0,2,1)
-        proj, _ = self.projector(proj)
+        proj, norm_proj = self.projector(proj)
 
         # Mask projections
         h = self.apply_masking(proj, mask_time_indices)
         h = self.encoder(h)
 
         if return_projections:
-            return h, proj
+            return h, norm_proj
         return h
 
 
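For context, a minimal sketch of the projector contract this hunk relies on: the projector returns both the projected features and a normalized copy, and `return_projections` now hands back the normalized one. The module below follows the Hugging Face `Wav2Vec2FeatureProjection` pattern that this code resembles (normalize, then project); the layer sizes and internals are illustrative assumptions, not this repo's exact code.

```python
import torch
from torch import nn

# Illustrative stand-in for self.projector; follows the wav2vec 2.0
# FeatureProjection pattern. Sizes and internals are assumptions.
class FeatureProjection(nn.Module):
    def __init__(self, in_dim=256, hidden_dim=1024, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_dim)
        self.projection = nn.Linear(in_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        norm_x = self.layer_norm(x)                   # normalized projections
        proj = self.dropout(self.projection(norm_x))  # projected features
        return proj, norm_x

proj, norm_proj = FeatureProjection()(torch.randn(2, 50, 256))
print(proj.shape, norm_proj.shape)  # [2, 50, 1024], [2, 50, 256]
```

Under that reading, the old code returned (and, in the last hunk below, quantized) the dropout-regularized projection; the fix returns the normalized features instead, which is what wav2vec 2.0-style training quantizes.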
@@ -538,7 +538,6 @@ class ContrastiveTrainingWrapper(nn.Module):
         super().__init__()
         self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
                            mask_time_length=mask_time_length, **kwargs)
-        self.dropout_features = nn.Dropout(dropout)
         self.num_negatives = num_negatives
         self.mask_time_prob = mask_time_prob
         self.mask_time_length = mask_time_length
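This hunk drops `self.dropout_features` from the constructor; its only consumer disappears in the last hunk below, so nothing changes for callers. A hypothetical instantiation, with argument values that are illustrative only:

```python
# Hypothetical construction; values are illustrative, not the repo's defaults.
wrapper = ContrastiveTrainingWrapper(inner_dim=1024, dropout=0.1,
                                     mask_time_prob=0.65, mask_time_length=10,
                                     num_negatives=100)
```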
@@ -580,8 +579,6 @@ class ContrastiveTrainingWrapper(nn.Module):
         )
 
     def get_grad_norm_parameter_groups(self):
-        if self.freeze_main_net:
-            return {}
         groups = {
             'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()),
             'encoder': list(self.m2v.encoder.parameters()),
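After this hunk, `get_grad_norm_parameter_groups()` unconditionally returns the group dict (the `freeze_main_net` early-out is gone). A sketch of the kind of consumer such a method feeds, assuming only the `{name: [parameters]}` shape visible above; the logging function itself is hypothetical:

```python
import torch

# Hypothetical consumer: per-group L2 gradient norms, called after backward().
def log_grad_norms(wrapper):
    for name, params in wrapper.get_grad_norm_parameter_groups().items():
        grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
        if grads:
            print(f"{name}: grad_norm={torch.cat(grads).norm(2).item():.4f}")
```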
@@ -603,10 +600,8 @@ class ContrastiveTrainingWrapper(nn.Module):
         transformer_features = self.project_hid(outputs)
 
         # 2. quantize all (unmasked) extracted features and project to final vq dim
-        extract_features = self.dropout_features(proj)
-
         quantized_features, codevector_perplexity = self.quantizer(
-            extract_features, mask_time_indices=mask_time_indices
+            proj, mask_time_indices=mask_time_indices
         )
         quantized_features = self.project_q(quantized_features)
         batch_size, sequence_length, hidden_size = quantized_features.shape
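With this hunk the quantizer consumes `proj` (the normalized projections that the `Mel2Vec` change above now returns) directly, rather than a dropout-regularized copy of the un-normalized features. The toy below shows the wav2vec 2.0-style step these quantized targets typically feed into, namely temperature-scaled cosine similarity against the transformer outputs; the shapes and temperature are assumptions for illustration, not values from this repo:

```python
import torch
import torch.nn.functional as F

batch, seq, dim = 2, 50, 256  # assumed shapes
transformer_features = torch.randn(batch, seq, dim)  # stands in for project_hid(outputs)
quantized_features = torch.randn(batch, seq, dim)    # stands in for project_q(quantizer(proj))

# Positive logits: cosine similarity between each position's transformer
# output and its quantized target, scaled by a temperature.
temperature = 0.1
pos_logits = F.cosine_similarity(transformer_features, quantized_features,
                                 dim=-1) / temperature
print(pos_logits.shape)  # torch.Size([2, 50])
```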