forked from mrq/DL-Art-School
default to use input for conditioning & add preprocessed input to GDI
This commit is contained in:
parent 1b4d9567f3
commit 83a4ef4149
@@ -132,6 +132,7 @@ class TransformerDiffusion(nn.Module):
         self.enable_fp16 = use_fp16
         self.resolution_steps = resolution_steps
         self.max_window = max_window
+        self.preprocessed = None
 
         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
@@ -189,15 +190,20 @@ class TransformerDiffusion(nn.Module):
             s_prior = x_prior[:,:,start:start+self.max_window]
             s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
             s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
-        return s, s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)
+        self.preprocessed = (s_prior, resolution)
+        return s
 
     def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
         unused_params = []
+        conditioning_input = x_prior if conditioning_input is None else conditioning_input
 
         h = x
         if resolution is None:
-            h, h_sub, resolution = self.input_to_random_resolution_and_window(x, x_prior)
-        else:
+            assert self.preprocessed is not None, 'Preprocessing function not called.'
+            h = x
+            h_sub, resolution = self.preprocessed
+            self.preprocessed = None
+        else:
             h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True)
             assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}'
 
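
The model-side change above splits what used to be a single call into two steps: input_to_random_resolution_and_window now stashes the resampled prior and the resolution on self.preprocessed and returns only the windowed target, while forward (when called with resolution=None) asserts that stash exists, consumes it, and clears it; conditioning_input also defaults to x_prior. The sketch below illustrates the implied call order; it is a minimal illustration rather than code from this commit, and training_step_sketch plus its timestep draw are placeholders.

    import torch
    from torch import nn

    def training_step_sketch(model: nn.Module, x: torch.Tensor, x_prior: torch.Tensor) -> torch.Tensor:
        # 1) Window/resample the clean target *before* noising; as a side effect the
        #    model stores (s_prior, resolution) on model.preprocessed.
        x = model.input_to_random_resolution_and_window(x, x_prior)
        # 2) Call forward() with resolution=None so it consumes (and clears) the stash.
        #    conditioning_input now defaults to x_prior per this commit.
        timesteps = torch.randint(0, 1000, (x.shape[0],), device=x.device)  # placeholder timestep draw
        return model(x, timesteps, x_prior=x_prior)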
@@ -47,6 +47,7 @@ class GaussianDiffusionInjector(Injector):
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
         self.causal_mode = opt_get(opt, ['causal_mode'], False)
         self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8])
+        self.preprocess_fn = opt_get(opt, ['preprocess_fn'], None)
 
         k = 0
         if 'channel_balancer_proportion' in opt.keys():
@@ -88,6 +89,9 @@ class GaussianDiffusionInjector(Injector):
             sampler = self.schedule_sampler
         self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
         model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
+        if self.preprocess_fn is not None:
+            hq = getattr(gen, self.preprocess_fn)(hq, **model_inputs)
+
         t, weights = sampler.sample(hq.shape[0], hq.device)
         if self.causal_mode:
             cs, ce = self.causal_slope_range
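
On the injector side, GaussianDiffusionInjector now reads an optional preprocess_fn name from its options and, when set, resolves that method on the wrapped generator and applies it to hq (with the usual model_inputs) before timesteps are sampled. Below is a hedged sketch of how the new option and dispatch fit together: only keys visible in this diff are listed, the preprocess_fn value is an assumption based on the model-side half of this commit, and apply_preprocess is an illustrative helper rather than the injector's actual code.

    # Fragment of an injector opt dict; only keys that appear in this diff are shown.
    injector_opt_fragment = {
        'causal_mode': False,                                      # existing option (default False)
        'causal_slope_range': [1, 8],                              # existing option (default [1, 8])
        'preprocess_fn': 'input_to_random_resolution_and_window',  # new; assumed to name the model method above
    }

    def apply_preprocess(gen, hq, model_inputs, preprocess_fn=None):
        # Mirrors the injector-side change: resolve the method by name on the wrapped
        # generator and let it rewrite hq before t, weights are sampled.
        if preprocess_fn is not None:
            hq = getattr(gen, preprocess_fn)(hq, **model_inputs)
        return hq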