diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 9dd67410..6cc315ac 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -174,7 +174,7 @@ class TransformerDiffusion(nn.Module): assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}' resolution = randrange(1, self.resolution_steps) resolution_scale = 2 ** resolution - s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest', align_corners=True) + s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True) s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True) s_diff = s.shape[-1] - self.max_window if s_diff > 1: @@ -265,5 +265,5 @@ def remove_conditioning(sd_path): if __name__ == '__main__': - remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator.pth') + remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth') test_tfd()