default to use input for conditioning & add preprocessed input to GDI

pull/2/head
James Betker 2022-07-18 17:01:19 +07:00
parent 1b4d9567f3
commit 83a4ef4149
2 changed files with 13 additions and 3 deletions

@ -132,6 +132,7 @@ class TransformerDiffusion(nn.Module):
self.enable_fp16 = use_fp16
self.resolution_steps = resolution_steps
self.max_window = max_window
self.preprocessed = None
self.time_embed = nn.Sequential(
linear(time_embed_dim, time_embed_dim),
@ -189,15 +190,20 @@ class TransformerDiffusion(nn.Module):
s_prior = x_prior[:,:,start:start+self.max_window]
s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
return s, s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)
self.preprocessed = (s_prior, resolution)
return s
def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
unused_params = []
conditioning_input = x_prior if conditioning_input is None else conditioning_input
h = x
if resolution is None:
h, h_sub, resolution = self.input_to_random_resolution_and_window(x, x_prior)
else:
assert self.preprocessed is not None, 'Preprocessing function not called.'
h = x
h_sub, resolution = self.preprocessed
self.preprocessed = None
else:
h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True)
assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}'

@ -47,6 +47,7 @@ class GaussianDiffusionInjector(Injector):
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
self.causal_mode = opt_get(opt, ['causal_mode'], False)
self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8])
self.preprocess_fn = opt_get(opt, ['preprocess_fn'], None)
k = 0
if 'channel_balancer_proportion' in opt.keys():
@ -88,6 +89,9 @@ class GaussianDiffusionInjector(Injector):
sampler = self.schedule_sampler
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
if self.preprocess_fn is not None:
hq = getattr(gen, self.preprocess_fn)(hq, **model_inputs)
t, weights = sampler.sample(hq.shape[0], hq.device)
if self.causal_mode:
cs, ce = self.causal_slope_range