This commit is contained in:
James Betker 2022-05-23 12:28:36 -06:00
parent 560b83e770
commit 1e1bbe1a27

View File

@ -663,6 +663,7 @@ class ContrastiveTrainingWrapper(nn.Module):
def forward(self, mel, inp_lengths=None):
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)
orig_mel = mel
# Frequency masking
freq_mask_width = int(random.random() * self.freq_mask_percent * mel.shape[1])
@ -731,7 +732,7 @@ class ContrastiveTrainingWrapper(nn.Module):
if self.reconstruction:
reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1))
reconstruction_loss = F.mse_loss(reconstruction, mel)
reconstruction_loss = F.mse_loss(reconstruction, orig_mel)
return contrastive_loss, diversity_loss, reconstruction_loss
return contrastive_loss, diversity_loss