This commit is contained in:
James Betker 2022-06-09 21:52:57 -06:00
parent 07bdd865dc
commit 6e57eaa186

View File

@ -787,7 +787,7 @@ class GaussianDiffusion:
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):