another fix for causal diffusion in inference

This commit is contained in:
James Betker 2022-07-09 15:29:47 -06:00
parent 8657d4d060
commit f28cefdfe2

View File

@ -603,6 +603,7 @@ class GaussianDiffusion:
img = out["sample"]
if torch.any(mask):
img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
orig_img = img
def p_sample_loop_with_guidance(
self,
@ -834,6 +835,7 @@ class GaussianDiffusion:
img = out["sample"]
if torch.any(mask):
img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
orig_img = orig_img
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None