probably too harsh on ninfs

pull/2/head
James Betker 2022-07-28 01:33:20 +07:00
parent 4509cfc705
commit d44ed5d12d
1 changed files with 2 additions and 3 deletions

@ -697,9 +697,8 @@ class GaussianDiffusion:
logperp = torch.log(normpdf(err, mean, var)) / self.num_timesteps + logperp
# Remove -infs, which do happen pretty regularly (and penalize them proportionately).
num_infs = torch.isinf(logperp).sum()
logperp[torch.isinf(logperp)] = torch.max(logperp) * num_infs * 2
print(f'Num infs: : {num_infs}') # probably should just log this.
logperp[torch.isinf(logperp)] = torch.max(logperp) * 2
print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this.
return -logperp.mean()
def ddim_sample(