Fix dvae test failure

This commit is contained in:
James Betker 2021-11-18 00:58:36 -07:00
parent 019acfa4c5
commit 1287915f3c

View File

@ -228,7 +228,7 @@ class DiscreteVAE(nn.Module):
out = d(out) out = d(out)
else: else:
# This is non-differentiable, but gives a better idea of how the network is actually performing. # This is non-differentiable, but gives a better idea of how the network is actually performing.
out = self.decode(codes) out, _ = self.decode(codes)
# reconstruction loss # reconstruction loss
recon_loss = self.loss_fn(img, out, reduction='none') recon_loss = self.loss_fn(img, out, reduction='none')