i like this better
This commit is contained in:
parent
d44ed5d12d
commit
cfe907f13f
|
@ -12,6 +12,7 @@ import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch as th
|
import torch as th
|
||||||
|
from torch.distributions import Normal
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from models.diffusion.nn import mean_flat
|
from models.diffusion.nn import mean_flat
|
||||||
|
@ -366,14 +367,17 @@ class GaussianDiffusion:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||||
|
assert 'why are you doing this?'
|
||||||
pred_xstart = process_xstart(
|
pred_xstart = process_xstart(
|
||||||
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
|
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
|
||||||
)
|
)
|
||||||
model_mean = model_output
|
model_mean = model_output
|
||||||
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
|
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
|
||||||
if self.model_mean_type == ModelMeanType.START_X:
|
if self.model_mean_type == ModelMeanType.START_X:
|
||||||
|
assert 'bad boy.'
|
||||||
pred_xstart = process_xstart(model_output)
|
pred_xstart = process_xstart(model_output)
|
||||||
else:
|
else:
|
||||||
|
eps = model_output
|
||||||
pred_xstart = process_xstart(
|
pred_xstart = process_xstart(
|
||||||
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
||||||
)
|
)
|
||||||
|
@ -391,6 +395,7 @@ class GaussianDiffusion:
|
||||||
"variance": model_variance,
|
"variance": model_variance,
|
||||||
"log_variance": model_log_variance,
|
"log_variance": model_log_variance,
|
||||||
"pred_xstart": pred_xstart,
|
"pred_xstart": pred_xstart,
|
||||||
|
"pred_eps": eps,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||||
|
@ -510,7 +515,7 @@ class GaussianDiffusion:
|
||||||
cond_fn, out, x, t, model_kwargs=model_kwargs
|
cond_fn, out, x, t, model_kwargs=model_kwargs
|
||||||
)
|
)
|
||||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||||
return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
|
return {"sample": sample, "pred_xstart": out["pred_xstart"], "pred_eps": out["pred_eps"], "mean": out["mean"], "log_variance": out["log_variance"]}
|
||||||
|
|
||||||
def p_sample_loop(
|
def p_sample_loop(
|
||||||
self,
|
self,
|
||||||
|
@ -673,7 +678,8 @@ class GaussianDiffusion:
|
||||||
indices = list(range(self.num_timesteps))[::-1]
|
indices = list(range(self.num_timesteps))[::-1]
|
||||||
|
|
||||||
img = noise
|
img = noise
|
||||||
logperp = 1
|
#perp = self.num_timesteps
|
||||||
|
logperp = 0
|
||||||
for i in tqdm(indices):
|
for i in tqdm(indices):
|
||||||
t = th.tensor([i] * shape[0], device=device)
|
t = th.tensor([i] * shape[0], device=device)
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
|
@ -686,20 +692,16 @@ class GaussianDiffusion:
|
||||||
cond_fn=cond_fn,
|
cond_fn=cond_fn,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
mean = out["mean"]
|
eps = out["pred_eps"]
|
||||||
var = out["log_variance"].exp()
|
err = noise - eps
|
||||||
q = self.q_sample(truth, t, noise=noise)
|
|
||||||
err = out["sample"] - q
|
|
||||||
def normpdf(x, mean, var):
|
|
||||||
denom = (2 * math.pi * var)**.5
|
|
||||||
num = torch.exp(-(x-mean)**2/(2*var))
|
|
||||||
return num / denom
|
|
||||||
|
|
||||||
logperp = torch.log(normpdf(err, mean, var)) / self.num_timesteps + logperp
|
m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
|
||||||
# Remove -infs, which do happen pretty regularly (and penalize them proportionately).
|
nprobs = m.cdf(-err.abs().cpu()) * 2
|
||||||
logperp[torch.isinf(logperp)] = torch.max(logperp) * 2
|
logperp = torch.log(nprobs) / self.num_timesteps + logperp
|
||||||
print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this.
|
#perp = nprobs * perp
|
||||||
return -logperp.mean()
|
print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this separately.
|
||||||
|
logperp[torch.isinf(logperp)] = logperp.max() * 2
|
||||||
|
return -logperp
|
||||||
|
|
||||||
def ddim_sample(
|
def ddim_sample(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user