implement causal sampling for standard p_sampling
This commit is contained in:
parent
55b9f31825
commit
694400a45b
|
@ -479,9 +479,12 @@ class GaussianDiffusion:
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
noise = th.randn_like(x)
|
noise = th.randn_like(x)
|
||||||
|
if len(t.shape) == 1:
|
||||||
nonzero_mask = (
|
nonzero_mask = (
|
||||||
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
||||||
) # no noise when t == 0
|
) # no noise when t == 0
|
||||||
|
else:
|
||||||
|
nonzero_mask = (t != 0).float()
|
||||||
if cond_fn is not None:
|
if cond_fn is not None:
|
||||||
out["mean"] = self.condition_mean(
|
out["mean"] = self.condition_mean(
|
||||||
cond_fn, out, x, t, model_kwargs=model_kwargs
|
cond_fn, out, x, t, model_kwargs=model_kwargs
|
||||||
|
@ -495,6 +498,8 @@ class GaussianDiffusion:
|
||||||
shape,
|
shape,
|
||||||
noise=None,
|
noise=None,
|
||||||
clip_denoised=True,
|
clip_denoised=True,
|
||||||
|
causal=False,
|
||||||
|
causal_slope=1,
|
||||||
denoised_fn=None,
|
denoised_fn=None,
|
||||||
cond_fn=None,
|
cond_fn=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
@ -526,6 +531,8 @@ class GaussianDiffusion:
|
||||||
shape,
|
shape,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
clip_denoised=clip_denoised,
|
clip_denoised=clip_denoised,
|
||||||
|
causal=causal,
|
||||||
|
causal_slope=causal_slope,
|
||||||
denoised_fn=denoised_fn,
|
denoised_fn=denoised_fn,
|
||||||
cond_fn=cond_fn,
|
cond_fn=cond_fn,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
|
@ -541,6 +548,8 @@ class GaussianDiffusion:
|
||||||
shape,
|
shape,
|
||||||
noise=None,
|
noise=None,
|
||||||
clip_denoised=True,
|
clip_denoised=True,
|
||||||
|
causal=False,
|
||||||
|
causal_slope=1,
|
||||||
denoised_fn=None,
|
denoised_fn=None,
|
||||||
cond_fn=None,
|
cond_fn=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
@ -564,8 +573,15 @@ class GaussianDiffusion:
|
||||||
img = th.randn(*shape, device=device)
|
img = th.randn(*shape, device=device)
|
||||||
indices = list(range(self.num_timesteps))[::-1]
|
indices = list(range(self.num_timesteps))[::-1]
|
||||||
|
|
||||||
|
orig_img = img
|
||||||
for i in tqdm(indices):
|
for i in tqdm(indices):
|
||||||
t = th.tensor([i] * shape[0], device=device)
|
t = th.tensor([i] * shape[0], device=device)
|
||||||
|
mask = torch.zeros_like(img)
|
||||||
|
if causal:
|
||||||
|
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
|
||||||
|
mask = t == self.num_timesteps
|
||||||
|
t[mask] = self.num_timesteps-1
|
||||||
|
mask = mask.repeat(img.shape[0], img.shape[1], 1)
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
out = self.p_sample(
|
out = self.p_sample(
|
||||||
model,
|
model,
|
||||||
|
@ -578,6 +594,8 @@ class GaussianDiffusion:
|
||||||
)
|
)
|
||||||
yield out
|
yield out
|
||||||
img = out["sample"]
|
img = out["sample"]
|
||||||
|
if torch.any(mask):
|
||||||
|
img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
|
||||||
|
|
||||||
def p_sample_loop_with_guidance(
|
def p_sample_loop_with_guidance(
|
||||||
self,
|
self,
|
||||||
|
@ -586,6 +604,8 @@ class GaussianDiffusion:
|
||||||
mask,
|
mask,
|
||||||
noise=None,
|
noise=None,
|
||||||
clip_denoised=True,
|
clip_denoised=True,
|
||||||
|
causal=False,
|
||||||
|
causal_slope=1,
|
||||||
denoised_fn=None,
|
denoised_fn=None,
|
||||||
cond_fn=None,
|
cond_fn=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
@ -607,6 +627,8 @@ class GaussianDiffusion:
|
||||||
img,
|
img,
|
||||||
t,
|
t,
|
||||||
clip_denoised=clip_denoised,
|
clip_denoised=clip_denoised,
|
||||||
|
causal=causal,
|
||||||
|
causal_slope=causal_slope,
|
||||||
denoised_fn=denoised_fn,
|
denoised_fn=denoised_fn,
|
||||||
cond_fn=cond_fn,
|
cond_fn=cond_fn,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user