padding
This commit is contained in:
parent 581bc7ac5c
commit 40ba802104
@@ -22,7 +22,7 @@ from models.diffusion.nn import (
 )
 from models.lucidrains.x_transformers import Encoder
 from trainer.networks import register_model
-from utils.util import checkpoint, print_network
+from utils.util import checkpoint, print_network, ceil_multiple


 class TimestepBlock(nn.Module):

@@ -646,14 +646,13 @@ class UNetMusicModel(nn.Module):
         )

     def forward(self, x, timesteps, y, conditioning_free=False):
-        """
-        Apply the model to an input batch.
-
-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :param y: a batch of guidance latents from a quantizer
-        :return: an [N x C x ...] Tensor of outputs.
-        """
+        orig_x_shape = x.shape[-1]
+        cm = ceil_multiple(x.shape[-1], 16)
+        if cm != 0:
+            pc = (cm - x.shape[-1]) / x.shape[-1]
+            x = F.pad(x, (0, cm - x.shape[-1]))
+            y = F.pad(y.permute(0,2,1), (0, int(pc * y.shape[-1]))).permute(0,2,1)
+
         unused_params = []
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
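
The padding preamble added here rounds the time axis of x up to the next multiple of 16, so the U-Net's down/up-sampling stack divides the sequence evenly, and right-pads the conditioning latents y by the same fraction pc so they stay aligned with x; the prediction is cropped back to orig_x_shape further down. ceil_multiple comes from utils.util and its source is not part of this diff, so the sketch below uses a stand-in whose behaviour is only inferred from the "if cm != 0:" guard; the analogous pad of y is omitted because its exact layout depends on the quantizer output.

    import torch
    import torch.nn.functional as F

    def ceil_multiple(base, multiple):
        # Stand-in for utils.util.ceil_multiple (assumed behaviour, not the real source):
        # return 0 when base already divides evenly, otherwise the next multiple of `multiple`.
        res = base % multiple
        return 0 if res == 0 else base + (multiple - res)

    x = torch.randn(2, 256, 782)              # [N, C, T], same shape as the smoke test below
    orig_x_shape = x.shape[-1]                # 782, remembered so the output can be cropped later

    cm = ceil_multiple(x.shape[-1], 16)       # 782 is not a multiple of 16, so cm == 784
    if cm != 0:
        pc = (cm - x.shape[-1]) / x.shape[-1] # fraction of extra length added (2/782 here)
        x = F.pad(x, (0, cm - x.shape[-1]))   # zero-pad the right end of the time axis
    print(x.shape)                            # torch.Size([2, 256, 784])
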
@@ -692,7 +691,8 @@ class UNetMusicModel(nn.Module):
                 extraneous_addition = extraneous_addition + p.mean()
             h = h + extraneous_addition * 0

-        return self.out(h)
+        out = self.out(h)
+        return out[:, :, :orig_x_shape]


 class UNetMusicModelWithQuantizer(nn.Module):
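
Because h keeps the padded length all the way through the U-Net, self.out(h) is two samples too long in the example above; the changed return slices the prediction back to orig_x_shape so callers always get a tensor matching their input length. A minimal, self-contained illustration (the random tensor merely stands in for self.out(h); out_channels is 512 in the smoke test):

    import torch

    orig_x_shape = 782                 # input length remembered at the top of forward()
    out = torch.randn(2, 512, 784)     # stand-in for self.out(h) on the padded sequence
    out = out[:, :, :orig_x_shape]     # drop the padded tail
    print(out.shape)                   # torch.Size([2, 512, 782])
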
@@ -746,8 +746,8 @@ def register_unet_diffusion_music_codes(opt_net, opt):


 if __name__ == '__main__':
-    clip = torch.randn(2, 256, 400)
-    cond = torch.randn(2, 256, 400)
+    clip = torch.randn(2, 256, 782)
+    cond = torch.randn(2, 256, 782)
     ts = torch.LongTensor([600, 600])
     model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024,
                                         attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
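
The smoke-test length changes from 400 to 782, which, unlike 400, is not a multiple of 16, so running the __main__ block now actually exercises the pad-then-crop path instead of skipping it:

    assert 400 % 16 == 0     # the old length would never enter the padding branch
    assert 782 % 16 == 14    # the new length forces a pad of 2 samples (to 784) and a crop back to 782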