forked from mrq/DL-Art-School
ddim_with_guidance
This commit is contained in:
parent
20ef9cc6b4
commit
cb24bef406
|
@ -776,6 +776,56 @@ class GaussianDiffusion:
|
||||||
final = sample
|
final = sample
|
||||||
return final["sample"]
|
return final["sample"]
|
||||||
|
|
||||||
|
def ddim_sample_loop_with_guidance(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
guidance_input,
|
||||||
|
mask,
|
||||||
|
noise=None,
|
||||||
|
clip_denoised=True,
|
||||||
|
causal=False,
|
||||||
|
causal_slope=1,
|
||||||
|
denoised_fn=None,
|
||||||
|
cond_fn=None,
|
||||||
|
model_kwargs=None,
|
||||||
|
eta=0.0,
|
||||||
|
):
|
||||||
|
device = guidance_input.device
|
||||||
|
shape = guidance_input.shape
|
||||||
|
if noise is not None:
|
||||||
|
img = noise
|
||||||
|
else:
|
||||||
|
img = th.randn(*shape, device=device)
|
||||||
|
indices = list(range(self.num_timesteps))[::-1]
|
||||||
|
|
||||||
|
orig_img = img
|
||||||
|
for i in tqdm(indices):
|
||||||
|
t = th.tensor([i] * shape[0], device=device)
|
||||||
|
c_mask = torch.zeros_like(img)
|
||||||
|
if causal:
|
||||||
|
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
|
||||||
|
t, c_mask = causal_mask_and_fix(t, self.num_timesteps)
|
||||||
|
t[c_mask] = self.num_timesteps-1
|
||||||
|
c_mask = c_mask.repeat(img.shape[0], img.shape[1], 1)
|
||||||
|
with th.no_grad():
|
||||||
|
out = self.ddim_sample(
|
||||||
|
model,
|
||||||
|
img,
|
||||||
|
t,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
denoised_fn=denoised_fn,
|
||||||
|
cond_fn=cond_fn,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
eta=eta,
|
||||||
|
)
|
||||||
|
model_driven_out = out["sample"] * mask.logical_not()
|
||||||
|
if torch.any(c_mask):
|
||||||
|
model_driven_out[c_mask] = orig_img[c_mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
|
||||||
|
guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
|
||||||
|
img = model_driven_out + guidance_driven_out
|
||||||
|
orig_img = orig_img
|
||||||
|
return img
|
||||||
|
|
||||||
def ddim_sample_loop_progressive(
|
def ddim_sample_loop_progressive(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user