diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 5e5687e3..6689922e 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -697,9 +697,8 @@ class GaussianDiffusion: logperp = torch.log(normpdf(err, mean, var)) / self.num_timesteps + logperp # Remove -infs, which do happen pretty regularly (and penalize them proportionately). - num_infs = torch.isinf(logperp).sum() - logperp[torch.isinf(logperp)] = torch.max(logperp) * num_infs * 2 - print(f'Num infs: : {num_infs}') # probably should just log this. + logperp[torch.isinf(logperp)] = torch.max(logperp) * 2 + print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this. return -logperp.mean() def ddim_sample(