fix size issues

This commit is contained in:
James Betker 2022-05-31 21:23:26 -06:00
parent 8a1b8e3e62
commit e8cb93a4e9

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from models.arch_util import zero_module
from trainer.networks import register_model
from utils.util import checkpoint
from utils.util import checkpoint, ceil_multiple
class Downsample(nn.Module):
@ -195,6 +195,11 @@ class MusicQuantizer(nn.Module):
return self.quantizer.get_codes(proj)
def forward(self, mel):
orig_mel = mel
cm = ceil_multiple(mel.shape[-1], 4)
if cm != 0:
mel = F.pad(mel, (0,cm-mel.shape[-1]))
h = self.down(mel)
h = self.encoder(h)
h = self.enc_norm(h.permute(0,2,1))
@ -202,7 +207,8 @@ class MusicQuantizer(nn.Module):
h = self.decoder(codevectors.permute(0,2,1))
reconstructed = self.up(h)
mse = F.mse_loss(reconstructed, mel)
reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
mse = F.mse_loss(reconstructed, orig_mel)
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
self.log_codes(codes)
@ -236,5 +242,5 @@ def register_music_quantizer(opt_net, opt):
if __name__ == '__main__':
model = MusicQuantizer()
mel = torch.randn((2,256,200))
mel = torch.randn((2,256,782))
model(mel)