forked from mrq/DL-Art-School
fix bug
This commit is contained in:
parent
07bdd865dc
commit
6e57eaa186
|
@ -787,7 +787,7 @@ class GaussianDiffusion:
|
|||
|
||||
# At the first timestep return the decoder NLL,
|
||||
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
output = th.where((t == 0), decoder_nll, kl)
|
||||
output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
|
||||
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
|
||||
|
|
Loading…
Reference in New Issue
Block a user