From 19eb939ccfa213ce8540d9777d4c6b425501f01f Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 28 Jul 2022 00:23:35 -0600 Subject: [PATCH] gd perplexity # Conflicts: # codes/trainer/eval/music_diffusion_fid.py --- codes/models/diffusion/gaussian_diffusion.py | 42 +++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 4a57ebb7..ee74e03e 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -510,7 +510,7 @@ class GaussianDiffusion: cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise - return {"sample": sample, "pred_xstart": out["pred_xstart"]} + return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]} def p_sample_loop( self, @@ -654,6 +654,46 @@ class GaussianDiffusion: img = model_driven_out + guidance_driven_out return img + def p_sample_loop_for_perplexity( + self, + model, + truth, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + ): + if device is None: + device = next(model.parameters()).device + shape = truth.shape + if noise is None: + noise = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + img = noise + perp = 1 + for i in tqdm(indices): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + mean = out["mean"] + std = out["log_variance"].exp().sqrt() + q = self.q_sample(truth, t, noise=noise) + err = out - q + prob = (err - mean) / std + perp = prob * perp + return perp + def ddim_sample( self, model,