Fix dvae test failure

This commit is contained in:
James Betker 2021-11-18 00:58:36 -07:00
parent 019acfa4c5
commit 1287915f3c

View File

@ -228,7 +228,7 @@ class DiscreteVAE(nn.Module):
out = d(out)
else:
# This is non-differentiable, but gives a better idea of how the network is actually performing.
out = self.decode(codes)
out, _ = self.decode(codes)
# reconstruction loss
recon_loss = self.loss_fn(img, out, reduction='none')