This commit is contained in:
James Betker 2022-06-03 12:09:59 -06:00
parent 581bc7ac5c
commit 40ba802104

View File

@ -22,7 +22,7 @@ from models.diffusion.nn import (
)
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import checkpoint, print_network
from utils.util import checkpoint, print_network, ceil_multiple
class TimestepBlock(nn.Module):
@ -646,14 +646,13 @@ class UNetMusicModel(nn.Module):
)
def forward(self, x, timesteps, y, conditioning_free=False):
"""
Apply the model to an input batch.
orig_x_shape = x.shape[-1]
cm = ceil_multiple(x.shape[-1], 16)
if cm != 0:
pc = (cm - x.shape[-1]) / x.shape[-1]
x = F.pad(x, (0, cm - x.shape[-1]))
y = F.pad(y.permute(0,2,1), (0, int(pc * y.shape[-1]))).permute(0,2,1)
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: a batch of guidance latents from a quantizer
:return: an [N x C x ...] Tensor of outputs.
"""
unused_params = []
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@ -692,7 +691,8 @@ class UNetMusicModel(nn.Module):
extraneous_addition = extraneous_addition + p.mean()
h = h + extraneous_addition * 0
return self.out(h)
out = self.out(h)
return out[:, :, :orig_x_shape]
class UNetMusicModelWithQuantizer(nn.Module):
@ -746,8 +746,8 @@ def register_unet_diffusion_music_codes(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 256, 400)
cond = torch.randn(2, 256, 400)
clip = torch.randn(2, 256, 782)
cond = torch.randn(2, 256, 782)
ts = torch.LongTensor([600, 600])
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024,
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,