implement causal sampling for standard p_sampling
parent 55b9f31825
commit 694400a45b
@@ -479,9 +479,12 @@ class GaussianDiffusion:
             model_kwargs=model_kwargs,
         )
         noise = th.randn_like(x)
-        nonzero_mask = (
-            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
-        ) # no noise when t == 0
+        if len(t.shape) == 1:
+            nonzero_mask = (
+                (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+            ) # no noise when t == 0
+        else:
+            nonzero_mask = (t != 0).float()
         if cond_fn is not None:
             out["mean"] = self.condition_mean(
                 cond_fn, out, x, t, model_kwargs=model_kwargs
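For reference, the two branches in the hunk above produce differently shaped masks: the original path assumes one scalar timestep per sample and reshapes the mask so it broadcasts over every position, while the new else branch keeps the shape of a per-position timestep tensor. A minimal shape check, with toy sizes and `th` aliased to `torch` as in the file:

import torch as th

x = th.randn(2, 8, 128)                    # illustrative (batch, channels, width) tensor

# Per-sample timesteps, shape [2]: the mask is reshaped to [2, 1, 1] and broadcasts over x.
t_scalar = th.tensor([5, 0])
mask_scalar = (t_scalar != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
print(mask_scalar.shape)                   # torch.Size([2, 1, 1])

# Per-position timesteps, e.g. [2, 1, 128] after unsqueeze(1): already broadcastable against x.
t_causal = th.randint(0, 10, (2, 1, 128))
mask_causal = (t_causal != 0).float()
print(mask_causal.shape)                   # torch.Size([2, 1, 128])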
@@ -495,6 +498,8 @@ class GaussianDiffusion:
         shape,
         noise=None,
         clip_denoised=True,
+        causal=False,
+        causal_slope=1,
         denoised_fn=None,
         cond_fn=None,
         model_kwargs=None,
@@ -526,6 +531,8 @@ class GaussianDiffusion:
             shape,
             noise=noise,
             clip_denoised=clip_denoised,
+            causal=causal,
+            causal_slope=causal_slope,
             denoised_fn=denoised_fn,
             cond_fn=cond_fn,
             model_kwargs=model_kwargs,
@@ -541,6 +548,8 @@ class GaussianDiffusion:
         shape,
         noise=None,
         clip_denoised=True,
+        causal=False,
+        causal_slope=1,
         denoised_fn=None,
         cond_fn=None,
         model_kwargs=None,
@@ -564,8 +573,15 @@ class GaussianDiffusion:
             img = th.randn(*shape, device=device)
         indices = list(range(self.num_timesteps))[::-1]

+        orig_img = img
         for i in tqdm(indices):
             t = th.tensor([i] * shape[0], device=device)
+            mask = torch.zeros_like(img)
+            if causal:
+                t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
+                mask = t == self.num_timesteps
+                t[mask] = self.num_timesteps-1
+                mask = mask.repeat(img.shape[0], img.shape[1], 1)
             with th.no_grad():
                 out = self.p_sample(
                     model,
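The hunk above calls causal_timestep_adjustment, which is defined elsewhere in the repository and not shown in this diff. As a rough, assumed sketch only (not the actual implementation): it plausibly maps the per-sample timestep t of shape [B] to per-position timesteps of shape [B, shape[-1]], offset linearly along the sequence by causal_slope and clamped to num_timesteps so the caller can flag positions that have not started denoising yet.

import torch as th

def causal_timestep_adjustment(t, seq_len, num_timesteps, slope, add_jitter=False):
    # Hypothetical sketch -- not the repository's implementation.
    # t: [B] integer timesteps; returns [B, seq_len] per-position timesteps,
    # with later sequence positions sitting at larger (noisier) timesteps.
    pos = th.arange(seq_len, device=t.device)                 # 0 .. seq_len - 1
    offset = th.floor(slope * pos.float()).long()             # linear lag along the sequence
    t_pos = t[:, None] + offset[None, :]                      # [B, seq_len]
    if add_jitter:
        t_pos = t_pos + th.randint(-1, 2, t_pos.shape, device=t.device)  # assumed small jitter
    # Values clamped at num_timesteps are treated by the caller as "not yet
    # denoising": they are masked and reset to the original noise each step.
    return t_pos.clamp(0, num_timesteps)

In the loop above, the result is unsqueeze(1)-ed so it broadcasts against the (batch, channels, width) sample tensor.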
@@ -578,6 +594,8 @@ class GaussianDiffusion:
                 )
                 yield out
                 img = out["sample"]
+                if torch.any(mask):
+                    img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked.

     def p_sample_loop_with_guidance(
         self,
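For intuition, the reset added in the hunk above keeps positions that are still ahead of the causal schedule at their original noise until their adjusted timestep enters the schedule. A toy illustration (shapes and values are assumptions, not the repository's tensors):

import torch as th

B, C, W = 1, 2, 6
orig_img = th.randn(B, C, W)              # the starting noise kept around by the loop
img = orig_img.clone()

# Pretend positions 4 and 5 are still ahead of the causal schedule at this step.
mask = th.zeros(B, C, W, dtype=th.bool)
mask[..., 4:] = True

img = img - 0.1 * th.randn_like(img)      # stand-in for one p_sample update
img[mask] = orig_img[mask]                # masked positions snap back to their starting noise
assert th.equal(img[..., 4:], orig_img[..., 4:])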
@@ -586,6 +604,8 @@ class GaussianDiffusion:
         mask,
         noise=None,
         clip_denoised=True,
+        causal=False,
+        causal_slope=1,
         denoised_fn=None,
         cond_fn=None,
         model_kwargs=None,
@@ -607,6 +627,8 @@ class GaussianDiffusion:
                 img,
                 t,
                 clip_denoised=clip_denoised,
+                causal=causal,
+                causal_slope=causal_slope,
                 denoised_fn=denoised_fn,
                 cond_fn=cond_fn,
                 model_kwargs=model_kwargs,