From 6b1cfe8e6673b380db2a032c2e6582bb5eb7b294 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 19 Jul 2022 11:14:20 -0600 Subject: [PATCH] ugh --- codes/models/audio/music/transformer_diffusion13.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 9dd67410..6cc315ac 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -174,7 +174,7 @@ class TransformerDiffusion(nn.Module): assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}' resolution = randrange(1, self.resolution_steps) resolution_scale = 2 ** resolution - s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest', align_corners=True) + s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True) s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True) s_diff = s.shape[-1] - self.max_window if s_diff > 1: @@ -265,5 +265,5 @@ def remove_conditioning(sd_path): if __name__ == '__main__': - remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator.pth') + remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth') test_tfd()