implement causal sampling for standard p_sampling

This commit is contained in:
James Betker 2022-07-09 08:01:03 -06:00
parent 55b9f31825
commit 694400a45b

View File

@ -479,9 +479,12 @@ class GaussianDiffusion:
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
) )
noise = th.randn_like(x) noise = th.randn_like(x)
if len(t.shape) == 1:
nonzero_mask = ( nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0 ) # no noise when t == 0
else:
nonzero_mask = (t != 0).float()
if cond_fn is not None: if cond_fn is not None:
out["mean"] = self.condition_mean( out["mean"] = self.condition_mean(
cond_fn, out, x, t, model_kwargs=model_kwargs cond_fn, out, x, t, model_kwargs=model_kwargs
@ -495,6 +498,8 @@ class GaussianDiffusion:
shape, shape,
noise=None, noise=None,
clip_denoised=True, clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None, denoised_fn=None,
cond_fn=None, cond_fn=None,
model_kwargs=None, model_kwargs=None,
@ -526,6 +531,8 @@ class GaussianDiffusion:
shape, shape,
noise=noise, noise=noise,
clip_denoised=clip_denoised, clip_denoised=clip_denoised,
causal=causal,
causal_slope=causal_slope,
denoised_fn=denoised_fn, denoised_fn=denoised_fn,
cond_fn=cond_fn, cond_fn=cond_fn,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
@ -541,6 +548,8 @@ class GaussianDiffusion:
shape, shape,
noise=None, noise=None,
clip_denoised=True, clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None, denoised_fn=None,
cond_fn=None, cond_fn=None,
model_kwargs=None, model_kwargs=None,
@ -564,8 +573,15 @@ class GaussianDiffusion:
img = th.randn(*shape, device=device) img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1] indices = list(range(self.num_timesteps))[::-1]
orig_img = img
for i in tqdm(indices): for i in tqdm(indices):
t = th.tensor([i] * shape[0], device=device) t = th.tensor([i] * shape[0], device=device)
mask = torch.zeros_like(img)
if causal:
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
mask = t == self.num_timesteps
t[mask] = self.num_timesteps-1
mask = mask.repeat(img.shape[0], img.shape[1], 1)
with th.no_grad(): with th.no_grad():
out = self.p_sample( out = self.p_sample(
model, model,
@ -578,6 +594,8 @@ class GaussianDiffusion:
) )
yield out yield out
img = out["sample"] img = out["sample"]
if torch.any(mask):
img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
def p_sample_loop_with_guidance( def p_sample_loop_with_guidance(
self, self,
@ -586,6 +604,8 @@ class GaussianDiffusion:
mask, mask,
noise=None, noise=None,
clip_denoised=True, clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None, denoised_fn=None,
cond_fn=None, cond_fn=None,
model_kwargs=None, model_kwargs=None,
@ -607,6 +627,8 @@ class GaussianDiffusion:
img, img,
t, t,
clip_denoised=clip_denoised, clip_denoised=clip_denoised,
causal=causal,
causal_slope=causal_slope,
denoised_fn=denoised_fn, denoised_fn=denoised_fn,
cond_fn=cond_fn, cond_fn=cond_fn,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,