transition to nearest interpolation mode for downsampling

pull/2/head
James Betker 2022-07-20 10:56:17 +07:00
parent 7b3fc79737
commit b0e3be0a17
2 changed files with 3 additions and 3 deletions

@@ -178,12 +178,12 @@ class TransformerDiffusion(nn.Module):
"""
resolution = randrange(0, self.resolution_steps)
resolution_scale = 2 ** resolution
-s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
+s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest')
s_diff = s.shape[-1] - self.max_window
if s_diff > 1:
start = randrange(0, s_diff)
s = s[:,:,start:start+self.max_window]
-s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
+s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
return s

@@ -232,7 +232,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
mel_norm = normalize_torch_mel(mel)
#mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
conditioning = mel_norm[:,:,:1200]
-downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
+downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
stage1_shape = (1, 256, downsampled.shape[-1]*4)
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
# Chain super-sampling using 2 stages.