diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index 3fdcdce9..05ff8fe4 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -178,12 +178,12 @@ class TransformerDiffusion(nn.Module):
         """
         resolution = randrange(0, self.resolution_steps)
         resolution_scale = 2 ** resolution
-        s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
+        s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest')
         s_diff = s.shape[-1] - self.max_window
         if s_diff > 1:
             start = randrange(0, s_diff)
             s = s[:,:,start:start+self.max_window]
-        s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
+        s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
         s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
         self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
         return s
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 50c03dc1..8b8c10fa 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -232,7 +232,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         mel_norm = normalize_torch_mel(mel)
         #mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
         conditioning = mel_norm[:,:,:1200]
-        downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
+        downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
         stage1_shape = (1, 256, downsampled.shape[-1]*4)
         sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
         # Chain super-sampling using 2 stages.
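
For reference, a minimal sketch of the behavioural difference this diff introduces when downsampling along the time axis: linear interpolation with align_corners blends neighbouring frames, while nearest-neighbour simply picks one frame per window. The tensor shape below is illustrative only, not taken from the repo.

```python
import torch
import torch.nn.functional as F

# Dummy (batch, channels, length) tensor shaped like a normalized mel spectrogram.
x = torch.randn(1, 256, 1600)

# Old behaviour: linear interpolation blends adjacent frames when downsampling.
down_linear = F.interpolate(x, scale_factor=1/16, mode='linear', align_corners=True)

# New behaviour: nearest-neighbour keeps one unblended frame per window.
down_nearest = F.interpolate(x, scale_factor=1/16, mode='nearest')

print(down_linear.shape, down_nearest.shape)  # both torch.Size([1, 256, 100])
```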