fix size issues

This commit is contained in:
James Betker 2022-05-31 21:23:26 -06:00
parent 8a1b8e3e62
commit e8cb93a4e9

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
 from models.arch_util import zero_module
 from trainer.networks import register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, ceil_multiple


 class Downsample(nn.Module):
@ -195,6 +195,11 @@ class MusicQuantizer(nn.Module):
         return self.quantizer.get_codes(proj)

     def forward(self, mel):
+        orig_mel = mel
+        cm = ceil_multiple(mel.shape[-1], 4)
+        if cm != 0:
+            mel = F.pad(mel, (0,cm-mel.shape[-1]))
         h = self.down(mel)
         h = self.encoder(h)
         h = self.enc_norm(h.permute(0,2,1))
@ -202,7 +207,8 @@ class MusicQuantizer(nn.Module):
         h = self.decoder(codevectors.permute(0,2,1))
         reconstructed = self.up(h)

-        mse = F.mse_loss(reconstructed, mel)
+        reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
+        mse = F.mse_loss(reconstructed, orig_mel)
         diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
         self.log_codes(codes)
@ -236,5 +242,5 @@ def register_music_quantizer(opt_net, opt):
 if __name__ == '__main__':
     model = MusicQuantizer()
-    mel = torch.randn((2,256,200))
+    mel = torch.randn((2,256,782))
     model(mel)