|
|
|
@ -510,7 +510,7 @@ class GaussianDiffusion:
|
|
|
|
|
cond_fn, out, x, t, model_kwargs=model_kwargs
|
|
|
|
|
)
|
|
|
|
|
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
|
|
|
|
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
|
|
|
|
return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
|
|
|
|
|
|
|
|
|
|
def p_sample_loop(
|
|
|
|
|
self,
|
|
|
|
@ -654,6 +654,46 @@ class GaussianDiffusion:
|
|
|
|
|
img = model_driven_out + guidance_driven_out
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
def p_sample_loop_for_perplexity(
|
|
|
|
|
self,
|
|
|
|
|
model,
|
|
|
|
|
truth,
|
|
|
|
|
noise=None,
|
|
|
|
|
clip_denoised=True,
|
|
|
|
|
denoised_fn=None,
|
|
|
|
|
cond_fn=None,
|
|
|
|
|
model_kwargs=None,
|
|
|
|
|
device=None,
|
|
|
|
|
):
|
|
|
|
|
if device is None:
|
|
|
|
|
device = next(model.parameters()).device
|
|
|
|
|
shape = truth.shape
|
|
|
|
|
if noise is None:
|
|
|
|
|
noise = th.randn(*shape, device=device)
|
|
|
|
|
indices = list(range(self.num_timesteps))[::-1]
|
|
|
|
|
|
|
|
|
|
img = noise
|
|
|
|
|
perp = 1
|
|
|
|
|
for i in tqdm(indices):
|
|
|
|
|
t = th.tensor([i] * shape[0], device=device)
|
|
|
|
|
with th.no_grad():
|
|
|
|
|
out = self.p_sample(
|
|
|
|
|
model,
|
|
|
|
|
img,
|
|
|
|
|
t,
|
|
|
|
|
clip_denoised=clip_denoised,
|
|
|
|
|
denoised_fn=denoised_fn,
|
|
|
|
|
cond_fn=cond_fn,
|
|
|
|
|
model_kwargs=model_kwargs,
|
|
|
|
|
)
|
|
|
|
|
mean = out["mean"]
|
|
|
|
|
std = out["log_variance"].exp().sqrt()
|
|
|
|
|
q = self.q_sample(truth, t, noise=noise)
|
|
|
|
|
err = out - q
|
|
|
|
|
prob = (err - mean) / std
|
|
|
|
|
perp = prob * perp
|
|
|
|
|
return perp
|
|
|
|
|
|
|
|
|
|
def ddim_sample(
|
|
|
|
|
self,
|
|
|
|
|
model,
|
|
|
|
|