transition to nearest interpolation mode for downsampling
This commit is contained in:
parent
7b3fc79737
commit
b0e3be0a17
|
@ -178,12 +178,12 @@ class TransformerDiffusion(nn.Module):
|
|||
"""
|
||||
resolution = randrange(0, self.resolution_steps)
|
||||
resolution_scale = 2 ** resolution
|
||||
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
||||
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest')
|
||||
s_diff = s.shape[-1] - self.max_window
|
||||
if s_diff > 1:
|
||||
start = randrange(0, s_diff)
|
||||
s = s[:,:,start:start+self.max_window]
|
||||
s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
|
||||
s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
|
||||
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
|
||||
self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
|
||||
return s
|
||||
|
|
|
@ -232,7 +232,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
mel_norm = normalize_torch_mel(mel)
|
||||
#mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
|
||||
conditioning = mel_norm[:,:,:1200]
|
||||
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
|
||||
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
|
||||
stage1_shape = (1, 256, downsampled.shape[-1]*4)
|
||||
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
|
||||
# Chain super-sampling using 2 stages.
|
||||
|
|
Loading…
Reference in New Issue
Block a user