|
|
|
@ -178,12 +178,12 @@ class TransformerDiffusion(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
resolution = randrange(0, self.resolution_steps)
|
|
|
|
|
resolution_scale = 2 ** resolution
|
|
|
|
|
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
|
|
|
|
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest')
|
|
|
|
|
s_diff = s.shape[-1] - self.max_window
|
|
|
|
|
if s_diff > 1:
|
|
|
|
|
start = randrange(0, s_diff)
|
|
|
|
|
s = s[:,:,start:start+self.max_window]
|
|
|
|
|
s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True)
|
|
|
|
|
s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
|
|
|
|
|
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
|
|
|
|
|
self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
|
|
|
|
|
return s
|
|
|
|
|