gd perplexity
# Conflicts: # codes/trainer/eval/music_diffusion_fid.py
This commit is contained in:
parent
a1bbde8a43
commit
19eb939ccf
|
@ -510,7 +510,7 @@ class GaussianDiffusion:
|
|||
cond_fn, out, x, t, model_kwargs=model_kwargs
|
||||
)
|
||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"], "mean": out["mean"], "log_variance": out["log_variance"]}
|
||||
|
||||
def p_sample_loop(
|
||||
self,
|
||||
|
@ -654,6 +654,46 @@ class GaussianDiffusion:
|
|||
img = model_driven_out + guidance_driven_out
|
||||
return img
|
||||
|
||||
def p_sample_loop_for_perplexity(
|
||||
self,
|
||||
model,
|
||||
truth,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
):
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
shape = truth.shape
|
||||
if noise is None:
|
||||
noise = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
img = noise
|
||||
perp = 1
|
||||
for i in tqdm(indices):
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.p_sample(
|
||||
model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
mean = out["mean"]
|
||||
std = out["log_variance"].exp().sqrt()
|
||||
q = self.q_sample(truth, t, noise=noise)
|
||||
err = out - q
|
||||
prob = (err - mean) / std
|
||||
perp = prob * perp
|
||||
return perp
|
||||
|
||||
def ddim_sample(
|
||||
self,
|
||||
model,
|
||||
|
|
Loading…
Reference in New Issue
Block a user