diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index 43d2755a..66db3b21 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -190,7 +190,7 @@ class TransformerDiffusion(nn.Module):
         s_prior = x_prior[:,:,start:start+self.max_window]
         s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
         s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
-        self.preprocessed = (s_prior, resolution)
+        self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
         return s
 
     def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
@@ -249,13 +249,14 @@ def register_transformer_diffusion13(opt_net, opt):
 
 
 def test_tfd():
-    clip = torch.randn(2,256,2583)
-    cond = torch.randn(2,256,2583)
+    clip = torch.randn(2,256,10336)
+    cond = torch.randn(2,256,10336)
     ts = torch.LongTensor([600, 600])
     model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
                                  num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1)
     for k in range(100):
-        model(clip, ts, clip, conditioning_input=cond)
+        x = model.input_to_random_resolution_and_window(clip, x_prior=clip)
+        model(x, ts, clip)
 
 
 if __name__ == '__main__':
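
For context, a minimal sketch of the call pattern the revised test exercises, assuming only the constructor arguments and the input_to_random_resolution_and_window / forward signatures visible in this diff; the import path and shape comments are assumptions, not part of the change itself.

    import torch

    # Assumed import path, derived from the file location in the diff above.
    from codes.models.audio.music.transformer_diffusion13 import TransformerDiffusion

    clip = torch.randn(2, 256, 10336)   # (batch, channels, frames), as in the updated test
    ts = torch.LongTensor([600, 600])   # one diffusion timestep per batch element

    model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
                                 num_heads=512 // 64, input_vec_dim=256, num_layers=12, dropout=.1)

    # Step 1: crop the input to a random resolution/window. Per the first hunk, this call
    # appears to cache the matching prior window plus a per-sample resolution tensor in
    # model.preprocessed before returning the windowed input.
    x = model.input_to_random_resolution_and_window(clip, x_prior=clip)

    # Step 2: run the diffusion forward pass on the preprocessed window.
    out = model(x, ts, clip)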