diff --git a/codes/models/audio/music/unet_diffusion_music_codes.py b/codes/models/audio/music/unet_diffusion_music_codes.py
index 95927eec..1533b324 100644
--- a/codes/models/audio/music/unet_diffusion_music_codes.py
+++ b/codes/models/audio/music/unet_diffusion_music_codes.py
@@ -22,7 +22,7 @@ from models.diffusion.nn import (
 )
 from models.lucidrains.x_transformers import Encoder
 from trainer.networks import register_model
-from utils.util import checkpoint, print_network
+from utils.util import checkpoint, print_network, ceil_multiple
 
 
 class TimestepBlock(nn.Module):
@@ -646,14 +646,13 @@ class UNetMusicModel(nn.Module):
         )
 
     def forward(self, x, timesteps, y, conditioning_free=False):
-        """
-        Apply the model to an input batch.
+        orig_x_shape = x.shape[-1]
+        cm = ceil_multiple(x.shape[-1], 16)
+        if cm != 0:
+            pc = (cm - x.shape[-1]) / x.shape[-1]
+            x = F.pad(x, (0, cm - x.shape[-1]))
+            y = F.pad(y.permute(0,2,1), (0, int(pc * y.shape[-1]))).permute(0,2,1)
 
-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :param y: a batch of guidance latents from a quantizer
-        :return: an [N x C x ...] Tensor of outputs.
-        """
         unused_params = []
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@@ -692,7 +691,8 @@ class UNetMusicModel(nn.Module):
                 extraneous_addition = extraneous_addition + p.mean()
             h = h + extraneous_addition * 0
 
-        return self.out(h)
+        out = self.out(h)
+        return out[:, :, :orig_x_shape]
 
 
 class UNetMusicModelWithQuantizer(nn.Module):
@@ -746,8 +746,8 @@ def register_unet_diffusion_music_codes(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = torch.randn(2, 256, 400)
-    cond = torch.randn(2, 256, 400)
+    clip = torch.randn(2, 256, 782)
+    cond = torch.randn(2, 256, 782)
     ts = torch.LongTensor([600, 600])
     model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3,
                                         input_vec_dim=1024, attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
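For context: the hunks above pad the input along its last (time) axis up to a multiple of 16 before the U-Net runs, pad the conditioning latents `y` by the same proportion, and crop the model output back to the original length at the end, so callers can pass sequences of arbitrary length (the `__main__` test switches from 400 to 782 to exercise this). Below is a minimal, self-contained sketch of that pad-then-crop round-trip. It assumes `utils.util.ceil_multiple(length, multiple)` returns 0 when `length` is already a multiple and otherwise the length rounded up to the next multiple (which is why the diff checks `if cm != 0`); the helper, `forward_sketch`, and the shapes used here are illustrative assumptions, not the repository's exact code.

```python
# Illustrative sketch of the pad-then-crop round-trip added in forward().
# Assumption: ceil_multiple(length, multiple) -> 0 if `length` is already a
# multiple, otherwise the length rounded up to the next multiple.
import torch
import torch.nn.functional as F


def ceil_multiple(length: int, multiple: int) -> int:
    rem = length % multiple
    return 0 if rem == 0 else length + (multiple - rem)


def forward_sketch(x: torch.Tensor, model) -> torch.Tensor:
    # x: [N, C, T] with arbitrary T (e.g. 782, which is not a multiple of 16)
    orig_len = x.shape[-1]
    cm = ceil_multiple(orig_len, 16)
    if cm != 0:
        x = F.pad(x, (0, cm - orig_len))  # right-pad the time axis to a multiple of 16
    out = model(x)                        # stand-in for the U-Net body and self.out(h)
    return out[:, :, :orig_len]           # crop back to the caller's original length


if __name__ == '__main__':
    x = torch.randn(2, 256, 782)
    out = forward_sketch(x, torch.nn.Identity())
    print(out.shape)  # torch.Size([2, 256, 782]); internally processed at length 784
```

Padding to a multiple of 16 matters because the three-level `channel_mult=(1,2,3)` downsampling halves the time axis at each stage, and the skip connections only line up if every intermediate length divides evenly; cropping afterwards keeps the change invisible to callers.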