forked from mrq/DL-Art-School
gd perplexity
# Conflicts: # codes/trainer/eval/music_diffusion_fid.py
This commit is contained in:
parent
a1bbde8a43
commit
19eb939ccf
|
@ -510,7 +510,7 @@ class GaussianDiffusion:
|
||||||
cond_fn, out, x, t, model_kwargs=model_kwargs
|
cond_fn, out, x, t, model_kwargs=model_kwargs
|
||||||
)
|
)
|
||||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
|
||||||
|
|
||||||
def p_sample_loop(
|
def p_sample_loop(
|
||||||
self,
|
self,
|
||||||
|
@ -654,6 +654,46 @@ class GaussianDiffusion:
|
||||||
img = model_driven_out + guidance_driven_out
|
img = model_driven_out + guidance_driven_out
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def p_sample_loop_for_perplexity(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
truth,
|
||||||
|
noise=None,
|
||||||
|
clip_denoised=True,
|
||||||
|
denoised_fn=None,
|
||||||
|
cond_fn=None,
|
||||||
|
model_kwargs=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
if device is None:
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
shape = truth.shape
|
||||||
|
if noise is None:
|
||||||
|
noise = th.randn(*shape, device=device)
|
||||||
|
indices = list(range(self.num_timesteps))[::-1]
|
||||||
|
|
||||||
|
img = noise
|
||||||
|
perp = 1
|
||||||
|
for i in tqdm(indices):
|
||||||
|
t = th.tensor([i] * shape[0], device=device)
|
||||||
|
with th.no_grad():
|
||||||
|
out = self.p_sample(
|
||||||
|
model,
|
||||||
|
img,
|
||||||
|
t,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
denoised_fn=denoised_fn,
|
||||||
|
cond_fn=cond_fn,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
mean = out["mean"]
|
||||||
|
std = out["log_variance"].exp().sqrt()
|
||||||
|
q = self.q_sample(truth, t, noise=noise)
|
||||||
|
err = out - q
|
||||||
|
prob = (err - mean) / std
|
||||||
|
perp = prob * perp
|
||||||
|
return perp
|
||||||
|
|
||||||
def ddim_sample(
|
def ddim_sample(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user