Log codes when simply fetching codebook_indices

This commit is contained in:
James Betker 2021-12-06 09:21:43 -07:00
parent 380a5d5475
commit 662920bde3

View File

@ -188,6 +188,7 @@ class DiscreteVAE(nn.Module):
img = self.norm(images)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, codes, _ = self.codebook(logits)
self.log_codes(codes)
return codes
def decode(
@ -236,6 +237,7 @@ class DiscreteVAE(nn.Module):
out = sampled
for d in self.decoder:
out = d(out)
self.log_codes(codes)
else:
# This is non-differentiable, but gives a better idea of how the network is actually performing.
out, _ = self.decode(codes)
@ -243,9 +245,6 @@ class DiscreteVAE(nn.Module):
# reconstruction loss
recon_loss = self.loss_fn(img, out, reduction='none')
# This is so we can debug the distribution of codes being learned.
self.log_codes(codes)
return recon_loss, commitment_loss, out
def log_codes(self, codes):