i like this better

James Betker 2022-07-28 02:33:23 -06:00
parent d44ed5d12d
commit cfe907f13f


@@ -12,6 +12,7 @@ import random
 import numpy as np
 import torch
 import torch as th
+from torch.distributions import Normal
 from tqdm import tqdm
 
 from models.diffusion.nn import mean_flat
@@ -366,14 +367,17 @@ class GaussianDiffusion:
             return x
 
         if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+            assert 'why are you doing this?'
             pred_xstart = process_xstart(
                 self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
             )
             model_mean = model_output
         elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
             if self.model_mean_type == ModelMeanType.START_X:
+                assert 'bad boy.'
                 pred_xstart = process_xstart(model_output)
             else:
+                eps = model_output
                 pred_xstart = process_xstart(
                     self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                 )
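For reference, the EPSILON branch above leans on the standard DDPM identity x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, solved for x_0. A minimal standalone sketch of what `_predict_xstart_from_eps` computes, where `alpha_bar_t` stands in for the noise schedule's cumulative product at step t (the function and argument names here are illustrative, not the repo's):

    import torch

    def predict_xstart_from_eps_sketch(x_t, eps, alpha_bar_t):
        # Invert x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps for x_0.
        return (x_t - torch.sqrt(1.0 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)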
@@ -391,6 +395,7 @@ class GaussianDiffusion:
             "variance": model_variance,
             "log_variance": model_log_variance,
             "pred_xstart": pred_xstart,
+            "pred_eps": eps,
         }
 
     def _predict_xstart_from_eps(self, x_t, t, eps):
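Note that `eps` is only bound on the ModelMeanType.EPSILON path in the previous hunk, so `"pred_eps": eps` would presumably raise a NameError if this dict is built via the PREVIOUS_X or START_X branches. The asserts added to those branches are non-empty strings, which are always truthy and never fire, so they do not guard against this.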
@@ -510,7 +515,7 @@ class GaussianDiffusion:
                 cond_fn, out, x, t, model_kwargs=model_kwargs
             )
         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
-        return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
+        return {"sample": sample, "pred_xstart": out["pred_xstart"], "pred_eps": out["pred_eps"], "mean": out["mean"], "log_variance": out["log_variance"]}
 
     def p_sample_loop(
         self,
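For context, the unchanged `sample` line in this hunk is the usual reparameterized draw from the model's per-step Gaussian; a minimal sketch with hypothetical names, assuming `t` is the per-batch timestep tensor:

    import torch

    def p_sample_step_sketch(mean, log_variance, t):
        # x_{t-1} ~ N(mean, variance) via reparameterization; the mask zeroes
        # the noise at t == 0 so the final step returns the mean deterministically.
        noise = torch.randn_like(mean)
        nonzero_mask = (t != 0).float().view(-1, *([1] * (mean.dim() - 1)))
        return mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise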
@@ -673,7 +678,8 @@ class GaussianDiffusion:
         indices = list(range(self.num_timesteps))[::-1]
         img = noise
-        logperp = 1
+        #perp = self.num_timesteps
+        logperp = 0
         for i in tqdm(indices):
             t = th.tensor([i] * shape[0], device=device)
             with th.no_grad():
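Initializing the accumulator at 0 rather than 1 is consistent with working in log space: log-probabilities are summed, so the additive identity is 0, whereas 1 was the identity for the multiplicative `perp` formulation kept in the comment.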
@@ -686,20 +692,16 @@ class GaussianDiffusion:
                     cond_fn=cond_fn,
                     model_kwargs=model_kwargs,
                 )
-                mean = out["mean"]
-                var = out["log_variance"].exp()
-                q = self.q_sample(truth, t, noise=noise)
-                err = out["sample"] - q
-                def normpdf(x, mean, var):
-                    denom = (2 * math.pi * var)**.5
-                    num = torch.exp(-(x-mean)**2/(2*var))
-                    return num / denom
-                logperp = torch.log(normpdf(err, mean, var)) / self.num_timesteps + logperp
-                # Remove -infs, which do happen pretty regularly (and penalize them proportionately).
-                logperp[torch.isinf(logperp)] = torch.max(logperp) * 2
-                print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this.
-        return -logperp.mean()
+                eps = out["pred_eps"]
+                err = noise - eps
+                m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+                nprobs = m.cdf(-err.abs().cpu()) * 2
+                logperp = torch.log(nprobs) / self.num_timesteps + logperp
+                #perp = nprobs * perp
+                print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this separately.
+                logperp[torch.isinf(logperp)] = logperp.max() * 2
+        return -logperp
 
     def ddim_sample(
         self,
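The rewritten loop scores each residual with a two-sided tail probability under the standard normal: `Normal(0, 1).cdf(-|err|) * 2` equals P(|Z| >= |err|), i.e. how likely a unit Gaussian draw is to be at least as extreme as the observed eps error. A minimal standalone sketch of that quantity (function name illustrative):

    import torch
    from torch.distributions import Normal

    def two_sided_tail_prob(err):
        # P(|Z| >= |err|) for Z ~ N(0, 1): cdf(-|err|) is one tail,
        # doubling covers both by symmetry; result lies in (0, 1].
        m = Normal(torch.tensor(0.0), torch.tensor(1.0))
        return 2.0 * m.cdf(-err.abs())

Since `cdf` underflows to exactly 0 for large residuals, `torch.log(nprobs)` can still produce -inf, which is presumably why the isinf clamp and the diagnostic print survive the rewrite.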