forked from mrq/DL-Art-School
i like this better
This commit is contained in:
parent
d44ed5d12d
commit
cfe907f13f
|
@ -12,6 +12,7 @@ import random
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch as th
|
||||
from torch.distributions import Normal
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.diffusion.nn import mean_flat
|
||||
|
@ -366,14 +367,17 @@ class GaussianDiffusion:
|
|||
return x
|
||||
|
||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||
assert 'why are you doing this?'
|
||||
pred_xstart = process_xstart(
|
||||
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
|
||||
)
|
||||
model_mean = model_output
|
||||
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
|
||||
if self.model_mean_type == ModelMeanType.START_X:
|
||||
assert 'bad boy.'
|
||||
pred_xstart = process_xstart(model_output)
|
||||
else:
|
||||
eps = model_output
|
||||
pred_xstart = process_xstart(
|
||||
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
||||
)
|
||||
|
@ -391,6 +395,7 @@ class GaussianDiffusion:
|
|||
"variance": model_variance,
|
||||
"log_variance": model_log_variance,
|
||||
"pred_xstart": pred_xstart,
|
||||
"pred_eps": eps,
|
||||
}
|
||||
|
||||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||
|
@ -510,7 +515,7 @@ class GaussianDiffusion:
|
|||
cond_fn, out, x, t, model_kwargs=model_kwargs
|
||||
)
|
||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"], "pred_eps": out["pred_eps"], "mean": out["mean"], "log_variance": out["log_variance"]}
|
||||
|
||||
def p_sample_loop(
|
||||
self,
|
||||
|
@ -673,7 +678,8 @@ class GaussianDiffusion:
|
|||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
img = noise
|
||||
logperp = 1
|
||||
#perp = self.num_timesteps
|
||||
logperp = 0
|
||||
for i in tqdm(indices):
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
|
@ -686,20 +692,16 @@ class GaussianDiffusion:
|
|||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
mean = out["mean"]
|
||||
var = out["log_variance"].exp()
|
||||
q = self.q_sample(truth, t, noise=noise)
|
||||
err = out["sample"] - q
|
||||
def normpdf(x, mean, var):
|
||||
denom = (2 * math.pi * var)**.5
|
||||
num = torch.exp(-(x-mean)**2/(2*var))
|
||||
return num / denom
|
||||
eps = out["pred_eps"]
|
||||
err = noise - eps
|
||||
|
||||
logperp = torch.log(normpdf(err, mean, var)) / self.num_timesteps + logperp
|
||||
# Remove -infs, which do happen pretty regularly (and penalize them proportionately).
|
||||
logperp[torch.isinf(logperp)] = torch.max(logperp) * 2
|
||||
print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this.
|
||||
return -logperp.mean()
|
||||
m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
|
||||
nprobs = m.cdf(-err.abs().cpu()) * 2
|
||||
logperp = torch.log(nprobs) / self.num_timesteps + logperp
|
||||
#perp = nprobs * perp
|
||||
print(f'Num infs: : {torch.isinf(logperp).sum()}') # probably should just log this separately.
|
||||
logperp[torch.isinf(logperp)] = logperp.max() * 2
|
||||
return -logperp
|
||||
|
||||
def ddim_sample(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user