diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 6689922e..ae145143 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -12,6 +12,7 @@ import random
 import numpy as np
 import torch
 import torch as th
+from torch.distributions import Normal
 from tqdm import tqdm
 
 from models.diffusion.nn import mean_flat
@@ -366,14 +367,17 @@ class GaussianDiffusion:
             return x
 
         if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+            assert False, 'why are you doing this?'
             pred_xstart = process_xstart(
                 self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
             )
             model_mean = model_output
         elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
             if self.model_mean_type == ModelMeanType.START_X:
+                assert False, 'bad boy.'
                 pred_xstart = process_xstart(model_output)
             else:
+                eps = model_output
                 pred_xstart = process_xstart(
                     self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                 )
@@ -391,6 +395,7 @@
             "variance": model_variance,
             "log_variance": model_log_variance,
             "pred_xstart": pred_xstart,
+            "pred_eps": eps,
         }
 
     def _predict_xstart_from_eps(self, x_t, t, eps):
@@ -510,7 +515,7 @@
                 cond_fn, out, x, t, model_kwargs=model_kwargs
             )
         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
-        return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
+        return {"sample": sample, "pred_xstart": out["pred_xstart"], "pred_eps": out["pred_eps"], "mean": out["mean"], "log_variance": out["log_variance"]}
 
     def p_sample_loop(
         self,
@@ -673,7 +678,8 @@
         indices = list(range(self.num_timesteps))[::-1]
 
         img = noise
-        logperp = 1
+        #perp = self.num_timesteps
+        logperp = 0
         for i in tqdm(indices):
             t = th.tensor([i] * shape[0], device=device)
             with th.no_grad():
@@ -686,20 +692,16 @@
                     cond_fn=cond_fn,
                     model_kwargs=model_kwargs,
                 )
-                mean = out["mean"]
-                var = out["log_variance"].exp()
-                q = self.q_sample(truth, t, noise=noise)
-                err = out["sample"] - q
-                def normpdf(x, mean, var):
-                    denom = (2 * math.pi * var)**.5
-                    num = torch.exp(-(x-mean)**2/(2*var))
-                    return num / denom
+                eps = out["pred_eps"]
+                err = noise - eps
 
-                logperp = torch.log(normpdf(err, mean, var)) / self.num_timesteps + logperp
-        # Remove -infs, which do happen pretty regularly (and penalize them proportionately).
-        logperp[torch.isinf(logperp)] = torch.max(logperp) * 2
-        print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this.
-        return -logperp.mean()
+                m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+                nprobs = m.cdf(-err.abs().cpu()) * 2
+                logperp = torch.log(nprobs) / self.num_timesteps + logperp
+                #perp = nprobs * perp
+        print(f'Num infs: {torch.isinf(logperp).sum()}') # probably should just log this separately.
+        logperp[torch.isinf(logperp)] = logperp.max() * 2
+        return -logperp
 
     def ddim_sample(
         self,
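
For context on the new scoring rule: instead of evaluating the sampling error under a hand-rolled Gaussian density (the deleted normpdf), the loop now measures how surprising the residual noise - pred_eps is under the standard normal the noise was drawn from, via the two-sided tail probability 2 * Φ(-|err|). Below is a minimal, self-contained sketch of that computation; the function name stepwise_log_perplexity and the pred_eps_per_step input are hypothetical stand-ins for illustration, not part of this repo's API.

import torch
from torch.distributions import Normal

def stepwise_log_perplexity(pred_eps_per_step, noise, num_timesteps):
    # err = noise - pred_eps should be ~0 for a perfect model. Under the
    # assumption eps ~ N(0, 1), the chance of a residual at least as large
    # as |err| is 2 * Phi(-|err|): 1 for a perfect prediction (log prob 0),
    # near 0 for large errors (log prob -> -inf).
    m = Normal(torch.tensor(0.0), torch.tensor(1.0))
    logperp = torch.zeros_like(noise)
    for pred_eps in pred_eps_per_step:
        err = noise - pred_eps
        nprobs = m.cdf(-err.abs()) * 2
        logperp = logperp + torch.log(nprobs) / num_timesteps
    # Mirror the diff's -inf handling: replace underflows with twice the
    # largest finite log-probability (a proportional penalty, since the
    # accumulated values are <= 0).
    logperp[torch.isinf(logperp)] = logperp.max() * 2
    return -logperp

# Hypothetical usage with synthetic stand-ins for the per-step predictions:
noise = torch.randn(2, 4)
preds = [noise + 0.1 * torch.randn_like(noise) for _ in range(10)]
print(stepwise_log_perplexity(preds, noise, num_timesteps=10).mean())

Dividing by num_timesteps inside the loop keeps the accumulated value an average log-probability per step, so scores from runs with different numbers of timesteps stay comparable.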