From cb24bef4067c20ff118188ee625f9252b17f8fff Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 17 Jul 2022 18:24:43 -0600
Subject: [PATCH] ddim_with_guidance

---
 codes/models/diffusion/gaussian_diffusion.py | 50 ++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index ed0a8cf9..42ceb144 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -776,6 +776,56 @@ class GaussianDiffusion:
             final = sample
         return final["sample"]
 
+    def ddim_sample_loop_with_guidance(
+        self,
+        model,
+        guidance_input,
+        mask,
+        noise=None,
+        clip_denoised=True,
+        causal=False,
+        causal_slope=1,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        eta=0.0,
+    ):
+        device = guidance_input.device
+        shape = guidance_input.shape
+        if noise is not None:
+            img = noise
+        else:
+            img = th.randn(*shape, device=device)
+        indices = list(range(self.num_timesteps))[::-1]
+
+        orig_img = img
+        for i in tqdm(indices):
+            t = th.tensor([i] * shape[0], device=device)
+            c_mask = torch.zeros_like(img)
+            if causal:
+                t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
+                t, c_mask = causal_mask_and_fix(t, self.num_timesteps)
+                t[c_mask] = self.num_timesteps-1
+                c_mask = c_mask.repeat(img.shape[0], img.shape[1], 1)
+            with th.no_grad():
+                out = self.ddim_sample(
+                    model,
+                    img,
+                    t,
+                    clip_denoised=clip_denoised,
+                    denoised_fn=denoised_fn,
+                    cond_fn=cond_fn,
+                    model_kwargs=model_kwargs,
+                    eta=eta,
+                )
+                model_driven_out = out["sample"] * mask.logical_not()
+                if torch.any(c_mask):
+                    model_driven_out[c_mask] = orig_img[c_mask]  # For causal diffusion, keep resetting these predictions until they are unmasked.
+                guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
+                img = model_driven_out + guidance_driven_out
+                orig_img = orig_img
+        return img
+
     def ddim_sample_loop_progressive(
         self,
         model,
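
Usage note (not part of the patch): the sketch below shows one way the new method might be called for a guided, inpainting-style DDIM sample. The wrapper name `inpaint_with_ddim` and everything around it are assumptions for illustration; only `ddim_sample_loop_with_guidance` and its parameters come from the diff above. Per the code, positions where `mask` is 1 are repeatedly replaced each step with a q_sample-noised copy of `guidance_input`, while positions where `mask` is 0 are filled in by the DDIM model prediction; with `causal=True`, timesteps are additionally skewed along the last axis and still-masked positions are reset to the starting noise until they become active.

    def inpaint_with_ddim(diffusion, model, known_signal, mask, model_kwargs=None):
        # Hypothetical helper: `diffusion` is an already-constructed GaussianDiffusion
        # instance from this file, `model` is the denoising network, `known_signal`
        # is the partially-known tensor used as guidance, and `mask` is a tensor of
        # the same shape where 1 = pin to the (noised) guidance and 0 = let the
        # sampler fill the region in.
        return diffusion.ddim_sample_loop_with_guidance(
            model,
            guidance_input=known_signal,
            mask=mask,
            clip_denoised=True,
            causal=False,                    # True enables the causal timestep schedule
            model_kwargs=model_kwargs or {},
        )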