ddim_with_guidance

James Betker 2022-07-17 18:24:43 -06:00
parent 20ef9cc6b4
commit cb24bef406


@@ -776,6 +776,56 @@ class GaussianDiffusion:
            final = sample
        return final["sample"]
    def ddim_sample_loop_with_guidance(
        self,
        model,
        guidance_input,
        mask,
        noise=None,
        clip_denoised=True,
        causal=False,
        causal_slope=1,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        device = guidance_input.device
        shape = guidance_input.shape
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]
        orig_img = img
        for i in tqdm(indices):
            t = th.tensor([i] * shape[0], device=device)
            c_mask = th.zeros_like(img, dtype=th.bool)
            if causal:
                # Causal diffusion: give each sequence position its own timestep and
                # clamp positions that have not started denoising yet to maximum noise.
                t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps,
                                               causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
                t, c_mask = causal_mask_and_fix(t, self.num_timesteps)
                t[c_mask] = self.num_timesteps - 1
                c_mask = c_mask.repeat(img.shape[0], img.shape[1], 1)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
            # Outside the guidance mask, keep the model's DDIM prediction.
            model_driven_out = out["sample"] * mask.logical_not()
            if th.any(c_mask):
                # For causal diffusion, keep resetting these predictions to the
                # initial noise until they are unmasked.
                model_driven_out[c_mask] = orig_img[c_mask]
            # Inside the guidance mask, substitute the guidance input re-noised to timestep t.
            guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
            img = model_driven_out + guidance_driven_out
        return img
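
The new loop behaves like inpainting: wherever mask is 1, the sample is replaced by guidance_input re-noised to the current timestep, and wherever mask is 0 the model generates freely. A minimal usage sketch follows; the diffuser and unet objects, the tensor shapes, and the mask layout are illustrative assumptions, not taken from this commit.

import torch as th

# Hypothetical objects and shapes, for illustration only.
B, C, T = 1, 100, 2048                 # batch, channels, sequence length (assumed)
guidance = th.randn(B, C, T)           # known signal the sample should agree with
mask = th.zeros(B, C, T)
mask[..., : T // 2] = 1                # 1 = take from guidance_input, 0 = let the model generate

sample = diffuser.ddim_sample_loop_with_guidance(
    unet,                              # the diffusion model (assumed variable name)
    guidance_input=guidance,
    mask=mask,
    clip_denoised=True,
    model_kwargs={},
)
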
def ddim_sample_loop_progressive(
self,
model,