From 6e57eaa1861be8f6851bc05d492a3605fa14c48b Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 9 Jun 2022 21:52:57 -0600 Subject: [PATCH] fix bug --- codes/models/diffusion/gaussian_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 274c0520..77751a72 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -787,7 +787,7 @@ class GaussianDiffusion: # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) - output = th.where((t == 0), decoder_nll, kl) + output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):