diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 610591fd..678adcba 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -603,6 +603,7 @@ class GaussianDiffusion: img = out["sample"] if torch.any(mask): img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. + orig_img = img def p_sample_loop_with_guidance( self, @@ -834,6 +835,7 @@ class GaussianDiffusion: img = out["sample"] if torch.any(mask): img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked. + orig_img = orig_img def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None