forked from mrq/DL-Art-School
whoops
This commit is contained in:
parent
560b83e770
commit
1e1bbe1a27
|
@ -663,6 +663,7 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
def forward(self, mel, inp_lengths=None):
|
||||
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
|
||||
features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)
|
||||
orig_mel = mel
|
||||
|
||||
# Frequency masking
|
||||
freq_mask_width = int(random.random() * self.freq_mask_percent * mel.shape[1])
|
||||
|
@ -731,7 +732,7 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
|
||||
if self.reconstruction:
|
||||
reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1))
|
||||
reconstruction_loss = F.mse_loss(reconstruction, mel)
|
||||
reconstruction_loss = F.mse_loss(reconstruction, orig_mel)
|
||||
return contrastive_loss, diversity_loss, reconstruction_loss
|
||||
|
||||
return contrastive_loss, diversity_loss
|
||||
|
|
Loading…
Reference in New Issue
Block a user