From 694400a45ba55f3e159dc8790a017cb95761b4c4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 9 Jul 2022 08:01:03 -0600 Subject: [PATCH] implement causal sampling for standard p_sampling --- codes/models/diffusion/gaussian_diffusion.py | 28 +++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 35c4a10a..6ad457c0 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -479,9 +479,12 @@ class GaussianDiffusion: model_kwargs=model_kwargs, ) noise = th.randn_like(x) - nonzero_mask = ( - (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + if len(t.shape) == 1: + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + else: + nonzero_mask = (t != 0).float() if cond_fn is not None: out["mean"] = self.condition_mean( cond_fn, out, x, t, model_kwargs=model_kwargs @@ -495,6 +498,8 @@ class GaussianDiffusion: shape, noise=None, clip_denoised=True, + causal=False, + causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, @@ -526,6 +531,8 @@ class GaussianDiffusion: shape, noise=noise, clip_denoised=clip_denoised, + causal=causal, + causal_slope=causal_slope, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, @@ -541,6 +548,8 @@ class GaussianDiffusion: shape, noise=None, clip_denoised=True, + causal=False, + causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, @@ -564,8 +573,15 @@ class GaussianDiffusion: img = th.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] + orig_img = img for i in tqdm(indices): t = th.tensor([i] * shape[0], device=device) + mask = torch.zeros_like(img) + if causal: + t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1) + mask = t == self.num_timesteps + t[mask] = self.num_timesteps-1 + mask = mask.repeat(img.shape[0], img.shape[1], 1) with th.no_grad(): out = self.p_sample( model, @@ -578,6 +594,8 @@ class GaussianDiffusion: ) yield out img = out["sample"] + if torch.any(mask): + img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. def p_sample_loop_with_guidance( self, @@ -586,6 +604,8 @@ class GaussianDiffusion: mask, noise=None, clip_denoised=True, + causal=False, + causal_slope=1, denoised_fn=None, cond_fn=None, model_kwargs=None, @@ -607,6 +627,8 @@ class GaussianDiffusion: img, t, clip_denoised=clip_denoised, + causal=causal, + causal_slope=causal_slope, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs,