padding
commit 40ba802104
parent 581bc7ac5c
@@ -22,7 +22,7 @@ from models.diffusion.nn import (
 )
 from models.lucidrains.x_transformers import Encoder
 from trainer.networks import register_model
-from utils.util import checkpoint, print_network
+from utils.util import checkpoint, print_network, ceil_multiple
 
 
 class TimestepBlock(nn.Module):
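ceil_multiple is the only new import and it drives the padding added further down. Its definition is not part of this diff; judging from the call site (cm = ceil_multiple(x.shape[-1], 16) followed by if cm != 0:), it appears to return 0 when the length is already a multiple and the next multiple otherwise. A minimal sketch under that assumption, not the actual utils.util helper:

def ceil_multiple(base: int, mult: int) -> int:
    # Assumed behavior: 0 if `base` is already a multiple of `mult`, else the next multiple.
    rem = base % mult
    if rem == 0:
        return 0
    return base + (mult - rem)

assert ceil_multiple(400, 16) == 0    # already a multiple of 16 -> no padding needed
assert ceil_multiple(782, 16) == 784  # rounded up to the next multiple of 16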
@@ -646,14 +646,13 @@ class UNetMusicModel(nn.Module):
         )
 
     def forward(self, x, timesteps, y, conditioning_free=False):
-        """
-        Apply the model to an input batch.
+        orig_x_shape = x.shape[-1]
+        cm = ceil_multiple(x.shape[-1], 16)
+        if cm != 0:
+            pc = (cm - x.shape[-1]) / x.shape[-1]
+            x = F.pad(x, (0, cm - x.shape[-1]))
+            y = F.pad(y.permute(0,2,1), (0, int(pc * y.shape[-1]))).permute(0,2,1)
 
-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :param y: a batch of guidance latents from a quantizer
-        :return: an [N x C x ...] Tensor of outputs.
-        """
         unused_params = []
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
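The padding itself: forward now records the original sequence length, right-pads x up to the next multiple of 16 (presumably so the U-Net's downsampling stages divide the length evenly), and grows the guidance latents y by a proportional amount. Pulled out into a standalone function it looks roughly like the sketch below; pad_to_multiple is a hypothetical name, the [N, T, D] layout for y is inferred from the permute(0, 2, 1) calls, and ceil_multiple refers to the sketch above.

import torch
import torch.nn.functional as F

def pad_to_multiple(x, y, multiple=16):
    # x: [N, C, T] input codes; y: [N, T_y, D] guidance latents (layout assumed).
    cm = ceil_multiple(x.shape[-1], multiple)   # 0 if the length is already aligned
    if cm != 0:
        pc = (cm - x.shape[-1]) / x.shape[-1]   # fractional growth applied to x
        x = F.pad(x, (0, cm - x.shape[-1]))     # zero-pad x on the right
        # As in the diff, the pad amount uses y.shape[-1] taken before the permute;
        # with T_y == T_x (as in the __main__ test below) this keeps y in step with x.
        y = F.pad(y.permute(0, 2, 1), (0, int(pc * y.shape[-1]))).permute(0, 2, 1)
    return x, y

x, y = pad_to_multiple(torch.randn(2, 256, 782), torch.randn(2, 782, 1024))
print(x.shape, y.shape)  # torch.Size([2, 256, 784]) torch.Size([2, 784, 1024])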
@@ -692,7 +691,8 @@ class UNetMusicModel(nn.Module):
                 extraneous_addition = extraneous_addition + p.mean()
             h = h + extraneous_addition * 0
 
-        return self.out(h)
+        out = self.out(h)
+        return out[:, :, :orig_x_shape]
 
 
 class UNetMusicModelWithQuantizer(nn.Module):
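The matching change at the end of forward: rather than returning self.out(h) at the padded length, the output is sliced back to the length recorded in orig_x_shape, so callers never see the internal padding. A tiny illustration with made-up tensors:

import torch

out = torch.randn(2, 256, 784)        # hypothetical output at the padded length
orig_x_shape = 782                    # length recorded before padding
print(out[:, :, :orig_x_shape].shape) # torch.Size([2, 256, 782]) -- original length restored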
@@ -746,8 +746,8 @@ def register_unet_diffusion_music_codes(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = torch.randn(2, 256, 400)
-    cond = torch.randn(2, 256, 400)
+    clip = torch.randn(2, 256, 782)
+    cond = torch.randn(2, 256, 782)
     ts = torch.LongTensor([600, 600])
     model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024,
                                         attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
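The smoke test's sequence length moves from 400 to 782, which looks deliberate: 400 is already a multiple of 16 and would never hit the new padding branch, while 782 forces two frames of padding and the final crop. The arithmetic:

print(400 % 16)  # 0  -> ceil_multiple(400, 16) == 0, padding branch skipped
print(782 % 16)  # 14 -> padded by 2 frames to 784, then cropped back to 782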