diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 93bf59f2..098fafc0 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -663,6 +663,7 @@ class ContrastiveTrainingWrapper(nn.Module): def forward(self, mel, inp_lengths=None): mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math. features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult) + orig_mel = mel # Frequency masking freq_mask_width = int(random.random() * self.freq_mask_percent * mel.shape[1]) @@ -731,7 +732,7 @@ class ContrastiveTrainingWrapper(nn.Module): if self.reconstruction: reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1)) - reconstruction_loss = F.mse_loss(reconstruction, mel) + reconstruction_loss = F.mse_loss(reconstruction, orig_mel) return contrastive_loss, diversity_loss, reconstruction_loss return contrastive_loss, diversity_loss