From b0e3be0a17342c75e7ec3f77f8583baf9b90e16b Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 20 Jul 2022 10:56:17 -0600
Subject: [PATCH] transition to nearest interpolation mode for downsampling

---
 codes/models/audio/music/transformer_diffusion13.py | 4 ++--
 codes/trainer/eval/music_diffusion_fid.py           | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index 3fdcdce9..05ff8fe4 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -178,12 +178,12 @@ class TransformerDiffusion(nn.Module):
         """
         resolution = randrange(0, self.resolution_steps)
         resolution_scale = 2 ** resolution
-        s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
+        s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest')
         s_diff = s.shape[-1] - self.max_window
         if s_diff > 1:
             start = randrange(0, s_diff)
             s = s[:,:,start:start+self.max_window]
-        s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
+        s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
         s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
         self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
         return s
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 50c03dc1..8b8c10fa 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -232,7 +232,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         mel_norm = normalize_torch_mel(mel)
         #mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
         conditioning = mel_norm[:,:,:1200]
-        downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
+        downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
         stage1_shape = (1, 256, downsampled.shape[-1]*4)
         sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
         # Chain super-sampling using 2 stages.
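
Review note: a minimal standalone sketch of the behavioral change, for context only. It is not part of the patch; the input tensor and sizes are made up, but the two F.interpolate calls mirror the old and new code paths. The commit message does not state a rationale, though 'nearest' avoids blending neighboring samples when building lower-resolution signals.

    import torch
    import torch.nn.functional as F

    # (batch, channels, time) layout, as expected by 1D interpolation modes.
    x = torch.arange(16, dtype=torch.float32).view(1, 1, 16)

    # Old path: 'linear' blends neighboring samples while downsampling,
    # effectively low-pass filtering the signal.
    linear = F.interpolate(x, scale_factor=1/4, mode='linear', align_corners=True)

    # New path: 'nearest' picks one source sample per output position with
    # no blending. align_corners is not valid for 'nearest', hence its removal.
    nearest = F.interpolate(x, scale_factor=1/4, mode='nearest')

    print(linear)   # tensor([[[ 0.,  5., 10., 15.]]])
    print(nearest)  # tensor([[[ 0.,  4.,  8., 12.]]])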