forked from mrq/DL-Art-School
ddim_with_guidance
This commit is contained in:
parent
20ef9cc6b4
commit
cb24bef406
|
@ -776,6 +776,56 @@ class GaussianDiffusion:
|
|||
final = sample
|
||||
return final["sample"]
|
||||
|
||||
def ddim_sample_loop_with_guidance(
|
||||
self,
|
||||
model,
|
||||
guidance_input,
|
||||
mask,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
causal=False,
|
||||
causal_slope=1,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
eta=0.0,
|
||||
):
|
||||
device = guidance_input.device
|
||||
shape = guidance_input.shape
|
||||
if noise is not None:
|
||||
img = noise
|
||||
else:
|
||||
img = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
orig_img = img
|
||||
for i in tqdm(indices):
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
c_mask = torch.zeros_like(img)
|
||||
if causal:
|
||||
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
|
||||
t, c_mask = causal_mask_and_fix(t, self.num_timesteps)
|
||||
t[c_mask] = self.num_timesteps-1
|
||||
c_mask = c_mask.repeat(img.shape[0], img.shape[1], 1)
|
||||
with th.no_grad():
|
||||
out = self.ddim_sample(
|
||||
model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
eta=eta,
|
||||
)
|
||||
model_driven_out = out["sample"] * mask.logical_not()
|
||||
if torch.any(c_mask):
|
||||
model_driven_out[c_mask] = orig_img[c_mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
|
||||
guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
|
||||
img = model_driven_out + guidance_driven_out
|
||||
orig_img = orig_img
|
||||
return img
|
||||
|
||||
def ddim_sample_loop_progressive(
|
||||
self,
|
||||
model,
|
||||
|
|
Loading…
Reference in New Issue
Block a user