Log codes when simply fetching codebook_indices

This commit is contained in:
James Betker 2021-12-06 09:21:43 -07:00
parent 380a5d5475
commit 662920bde3

View File

@@ -188,6 +188,7 @@ class DiscreteVAE(nn.Module):
img = self.norm(images)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, codes, _ = self.codebook(logits)
self.log_codes(codes)
return codes

def decode(
@@ -236,6 +237,7 @@ class DiscreteVAE(nn.Module):
out = sampled
for d in self.decoder:
out = d(out)
self.log_codes(codes)
else:
# This is non-differentiable, but gives a better idea of how the network is actually performing.
out, _ = self.decode(codes)
@@ -243,9 +245,6 @@ class DiscreteVAE(nn.Module):
# reconstruction loss
recon_loss = self.loss_fn(img, out, reduction='none')
# This is so we can debug the distribution of codes being learned.
self.log_codes(codes)
return recon_loss, commitment_loss, out

def log_codes(self, codes):