From e8cb93a4e9e4ae9f419c89415537d316a4b2071d Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 31 May 2022 21:23:26 -0600 Subject: [PATCH] fix size issues --- codes/models/audio/music/music_quantizer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py index fa95ad62..7cfe611d 100644 --- a/codes/models/audio/music/music_quantizer.py +++ b/codes/models/audio/music/music_quantizer.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from models.arch_util import zero_module from trainer.networks import register_model -from utils.util import checkpoint +from utils.util import checkpoint, ceil_multiple class Downsample(nn.Module): @@ -195,6 +195,11 @@ class MusicQuantizer(nn.Module): return self.quantizer.get_codes(proj) def forward(self, mel): + orig_mel = mel + cm = ceil_multiple(mel.shape[-1], 4) + if cm != 0: + mel = F.pad(mel, (0,cm-mel.shape[-1])) + h = self.down(mel) h = self.encoder(h) h = self.enc_norm(h.permute(0,2,1)) @@ -202,7 +207,8 @@ class MusicQuantizer(nn.Module): h = self.decoder(codevectors.permute(0,2,1)) reconstructed = self.up(h) - mse = F.mse_loss(reconstructed, mel) + reconstructed = reconstructed[:, :, :orig_mel.shape[-1]] + mse = F.mse_loss(reconstructed, orig_mel) diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors self.log_codes(codes) @@ -236,5 +242,5 @@ def register_music_quantizer(opt_net, opt): if __name__ == '__main__': model = MusicQuantizer() - mel = torch.randn((2,256,200)) + mel = torch.randn((2,256,782)) model(mel) \ No newline at end of file