i like this better

This commit is contained in:
James Betker 2022-07-28 02:33:23 -06:00
parent d44ed5d12d
commit cfe907f13f

View File

@ -12,6 +12,7 @@ import random
import numpy as np
import torch
import torch as th
from torch.distributions import Normal
from tqdm import tqdm
from models.diffusion.nn import mean_flat
@ -366,14 +367,17 @@ class GaussianDiffusion:
return x
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
assert 'why are you doing this?'
pred_xstart = process_xstart(
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
)
model_mean = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X:
assert 'bad boy.'
pred_xstart = process_xstart(model_output)
else:
eps = model_output
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
@ -391,6 +395,7 @@ class GaussianDiffusion:
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
"pred_eps": eps,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
@ -510,7 +515,7 @@ class GaussianDiffusion:
cond_fn, out, x, t, model_kwargs=model_kwargs
)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
return {"sample": sample, "pred_xstart": out["pred_xstart"], "pred_eps": out["pred_eps"], "mean": out["mean"], "log_variance": out["log_variance"]}
def p_sample_loop(
self,
@ -673,7 +678,8 @@ class GaussianDiffusion:
indices = list(range(self.num_timesteps))[::-1]
img = noise
logperp = 1
#perp = self.num_timesteps
logperp = 0
for i in tqdm(indices):
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
@ -686,20 +692,16 @@ class GaussianDiffusion:
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
mean = out["mean"]
var = out["log_variance"].exp()
q = self.q_sample(truth, t, noise=noise)
err = out["sample"] - q
def normpdf(x, mean, var):
denom = (2 * math.pi * var)**.5
num = torch.exp(-(x-mean)**2/(2*var))
return num / denom
eps = out["pred_eps"]
err = noise - eps
logperp = torch.log(normpdf(err, mean, var)) / self.num_timesteps + logperp
# Remove -infs, which do happen pretty regularly (and penalize them proportionately).
logperp[torch.isinf(logperp)] = torch.max(logperp) * 2
print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this.
return -logperp.mean()
m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
nprobs = m.cdf(-err.abs().cpu()) * 2
logperp = torch.log(nprobs) / self.num_timesteps + logperp
#perp = nprobs * perp
print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this separately.
logperp[torch.isinf(logperp)] = logperp.max() * 2
return -logperp
def ddim_sample(
self,