implement causal sampling for standard p_sampling

This commit is contained in:
James Betker 2022-07-09 08:01:03 -06:00
parent 55b9f31825
commit 694400a45b

View File

@ -479,9 +479,12 @@ class GaussianDiffusion:
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if len(t.shape) == 1:
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
else:
nonzero_mask = (t != 0).float()
if cond_fn is not None:
out["mean"] = self.condition_mean(
cond_fn, out, x, t, model_kwargs=model_kwargs
@ -495,6 +498,8 @@ class GaussianDiffusion:
shape,
noise=None,
clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
@ -526,6 +531,8 @@ class GaussianDiffusion:
shape,
noise=noise,
clip_denoised=clip_denoised,
causal=causal,
causal_slope=causal_slope,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
@ -541,6 +548,8 @@ class GaussianDiffusion:
shape,
noise=None,
clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
@ -564,8 +573,15 @@ class GaussianDiffusion:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
orig_img = img
for i in tqdm(indices):
t = th.tensor([i] * shape[0], device=device)
mask = torch.zeros_like(img)
if causal:
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
mask = t == self.num_timesteps
t[mask] = self.num_timesteps-1
mask = mask.repeat(img.shape[0], img.shape[1], 1)
with th.no_grad():
out = self.p_sample(
model,
@ -578,6 +594,8 @@ class GaussianDiffusion:
)
yield out
img = out["sample"]
if torch.any(mask):
img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
def p_sample_loop_with_guidance(
self,
@ -586,6 +604,8 @@ class GaussianDiffusion:
mask,
noise=None,
clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
@ -607,6 +627,8 @@ class GaussianDiffusion:
img,
t,
clip_denoised=clip_denoised,
causal=causal,
causal_slope=causal_slope,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,