|
|
|
@ -132,6 +132,7 @@ class TransformerDiffusion(nn.Module):
|
|
|
|
|
self.enable_fp16 = use_fp16
|
|
|
|
|
self.resolution_steps = resolution_steps
|
|
|
|
|
self.max_window = max_window
|
|
|
|
|
self.preprocessed = None
|
|
|
|
|
|
|
|
|
|
self.time_embed = nn.Sequential(
|
|
|
|
|
linear(time_embed_dim, time_embed_dim),
|
|
|
|
@ -189,15 +190,20 @@ class TransformerDiffusion(nn.Module):
|
|
|
|
|
s_prior = x_prior[:,:,start:start+self.max_window]
|
|
|
|
|
s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
|
|
|
|
|
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
|
|
|
|
|
return s, s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)
|
|
|
|
|
self.preprocessed = (s_prior, resolution)
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
|
|
|
|
|
unused_params = []
|
|
|
|
|
conditioning_input = x_prior if conditioning_input is None else conditioning_input
|
|
|
|
|
|
|
|
|
|
h = x
|
|
|
|
|
if resolution is None:
|
|
|
|
|
h, h_sub, resolution = self.input_to_random_resolution_and_window(x, x_prior)
|
|
|
|
|
else:
|
|
|
|
|
assert self.preprocessed is not None, 'Preprocessing function not called.'
|
|
|
|
|
h = x
|
|
|
|
|
h_sub, resolution = self.preprocessed
|
|
|
|
|
self.preprocessed = None
|
|
|
|
|
else:
|
|
|
|
|
h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True)
|
|
|
|
|
assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}'
|
|
|
|
|
|
|
|
|
|