gd perplexity

# Conflicts:
#	codes/trainer/eval/music_diffusion_fid.py
pull/2/head
James Betker 2022-07-28 00:23:35 +07:00
parent a1bbde8a43
commit 19eb939ccf
1 changed file with 41 additions and 1 deletion

@@ -510,7 +510,7 @@ class GaussianDiffusion:
                 cond_fn, out, x, t, model_kwargs=model_kwargs
             )
         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
-        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+        return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
 
     def p_sample_loop(
         self,
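For context, here is a minimal sketch (not part of this commit) of how a caller could use the newly returned "mean" and "log_variance" keys to score a known x_{t-1} under the model's per-step Gaussian p(x_{t-1} | x_t). The names `diffusion`, `model`, `x_t`, `t`, and `x_prev` are illustrative assumptions with the usual guided-diffusion shapes:

```python
import math
import torch as th

def step_log_likelihood(diffusion, model, x_t, t, x_prev, model_kwargs=None):
    # Hypothetical helper, not in this commit: p_sample now also exposes the
    # Gaussian parameters of p(x_{t-1} | x_t) alongside the drawn sample.
    out = diffusion.p_sample(model, x_t, t, model_kwargs=model_kwargs)
    mean, log_var = out["mean"], out["log_variance"]
    # Elementwise log N(x_prev; mean, exp(log_var)), summed over non-batch dims.
    log_prob = -0.5 * (log_var + (x_prev - mean) ** 2 / log_var.exp() + math.log(2 * math.pi))
    return log_prob.flatten(1).sum(dim=1)
```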
@@ -654,6 +654,46 @@ class GaussianDiffusion:
         img = model_driven_out + guidance_driven_out
         return img
 
+    def p_sample_loop_for_perplexity(
+        self,
+        model,
+        truth,
+        noise=None,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        device=None,
+    ):
+        if device is None:
+            device = next(model.parameters()).device
+        shape = truth.shape
+        if noise is None:
+            noise = th.randn(*shape, device=device)
+        indices = list(range(self.num_timesteps))[::-1]
+        img = noise
+        perp = 1
+        for i in tqdm(indices):
+            t = th.tensor([i] * shape[0], device=device)
+            with th.no_grad():
+                out = self.p_sample(
+                    model,
+                    img,
+                    t,
+                    clip_denoised=clip_denoised,
+                    denoised_fn=denoised_fn,
+                    cond_fn=cond_fn,
+                    model_kwargs=model_kwargs,
+                )
+            # Per-step Gaussian parameters, exposed by the p_sample change above.
+            mean = out["mean"]
+            std = out["log_variance"].exp().sqrt()
+            # Ground-truth input noised to the current timestep.
+            q = self.q_sample(truth, t, noise=noise)
+            # p_sample returns a dict; compare its drawn sample against the noised truth.
+            err = out["sample"] - q
+            prob = (err - mean) / std
+            # Accumulate the product of normalized per-step residuals.
+            perp = prob * perp
+        return perp
+
     def ddim_sample(
         self,
         model,
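A hypothetical usage sketch for the new loop (not part of this commit): scoring a batch of ground-truth inputs against a trained model. `diffusion` (a GaussianDiffusion instance), `model`, the input shape, and the empty `model_kwargs` are assumptions for illustration; the conditioning inputs depend on the model being evaluated.

```python
import torch as th

device = next(model.parameters()).device
truth = th.randn(4, 100, 400, device=device)  # stand-in for real evaluation data
with th.no_grad():
    perp = diffusion.p_sample_loop_for_perplexity(
        model,
        truth,
        clip_denoised=True,
        model_kwargs={},  # pass conditioning tensors here if the model needs them
        device=device,
    )
# `perp` has the same shape as `truth`: one accumulated residual product per element.
print(perp.shape, perp.abs().mean().item())
```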