forked from mrq/DL-Art-School
fix size issues
This commit is contained in:
parent
8a1b8e3e62
commit
e8cb93a4e9
|
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
from models.arch_util import zero_module
|
from models.arch_util import zero_module
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint, ceil_multiple
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
class Downsample(nn.Module):
|
||||||
|
@ -195,6 +195,11 @@ class MusicQuantizer(nn.Module):
|
||||||
return self.quantizer.get_codes(proj)
|
return self.quantizer.get_codes(proj)
|
||||||
|
|
||||||
def forward(self, mel):
|
def forward(self, mel):
|
||||||
|
orig_mel = mel
|
||||||
|
cm = ceil_multiple(mel.shape[-1], 4)
|
||||||
|
if cm != 0:
|
||||||
|
mel = F.pad(mel, (0,cm-mel.shape[-1]))
|
||||||
|
|
||||||
h = self.down(mel)
|
h = self.down(mel)
|
||||||
h = self.encoder(h)
|
h = self.encoder(h)
|
||||||
h = self.enc_norm(h.permute(0,2,1))
|
h = self.enc_norm(h.permute(0,2,1))
|
||||||
|
@ -202,7 +207,8 @@ class MusicQuantizer(nn.Module):
|
||||||
h = self.decoder(codevectors.permute(0,2,1))
|
h = self.decoder(codevectors.permute(0,2,1))
|
||||||
reconstructed = self.up(h)
|
reconstructed = self.up(h)
|
||||||
|
|
||||||
mse = F.mse_loss(reconstructed, mel)
|
reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
|
||||||
|
mse = F.mse_loss(reconstructed, orig_mel)
|
||||||
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
|
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
|
||||||
|
|
||||||
self.log_codes(codes)
|
self.log_codes(codes)
|
||||||
|
@ -236,5 +242,5 @@ def register_music_quantizer(opt_net, opt):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = MusicQuantizer()
|
model = MusicQuantizer()
|
||||||
mel = torch.randn((2,256,200))
|
mel = torch.randn((2,256,782))
|
||||||
model(mel)
|
model(mel)
|
Loading…
Reference in New Issue
Block a user