Log codes when simply fetching codebook_indices
This commit is contained in:
parent
380a5d5475
commit
662920bde3
|
@ -188,6 +188,7 @@ class DiscreteVAE(nn.Module):
|
|||
img = self.norm(images)
|
||||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||
sampled, codes, _ = self.codebook(logits)
|
||||
self.log_codes(codes)
|
||||
return codes
|
||||
|
||||
def decode(
|
||||
|
@ -236,6 +237,7 @@ class DiscreteVAE(nn.Module):
|
|||
out = sampled
|
||||
for d in self.decoder:
|
||||
out = d(out)
|
||||
self.log_codes(codes)
|
||||
else:
|
||||
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
||||
out, _ = self.decode(codes)
|
||||
|
@ -243,9 +245,6 @@ class DiscreteVAE(nn.Module):
|
|||
# reconstruction loss
|
||||
recon_loss = self.loss_fn(img, out, reduction='none')
|
||||
|
||||
# This is so we can debug the distribution of codes being learned.
|
||||
self.log_codes(codes)
|
||||
|
||||
return recon_loss, commitment_loss, out
|
||||
|
||||
def log_codes(self, codes):
|
||||
|
|
Loading…
Reference in New Issue
Block a user