From f28cefdfe2b0664772158cea371e5f6d59d1d4ff Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 9 Jul 2022 15:29:47 -0600 Subject: [PATCH] another fix for causal diffusion in inference --- codes/models/diffusion/gaussian_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 610591fd..678adcba 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -603,6 +603,7 @@ class GaussianDiffusion: img = out["sample"] if torch.any(mask): img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. + orig_img = img def p_sample_loop_with_guidance( self, @@ -834,6 +835,7 @@ class GaussianDiffusion: img = out["sample"] if torch.any(mask): img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. + orig_img = orig_img def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None