Log codes when simply fetching codebook_indices
This commit is contained in:
parent
380a5d5475
commit
662920bde3
|
@ -188,6 +188,7 @@ class DiscreteVAE(nn.Module):
|
||||||
img = self.norm(images)
|
img = self.norm(images)
|
||||||
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
sampled, codes, _ = self.codebook(logits)
|
sampled, codes, _ = self.codebook(logits)
|
||||||
|
self.log_codes(codes)
|
||||||
return codes
|
return codes
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
|
@ -236,6 +237,7 @@ class DiscreteVAE(nn.Module):
|
||||||
out = sampled
|
out = sampled
|
||||||
for d in self.decoder:
|
for d in self.decoder:
|
||||||
out = d(out)
|
out = d(out)
|
||||||
|
self.log_codes(codes)
|
||||||
else:
|
else:
|
||||||
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
||||||
out, _ = self.decode(codes)
|
out, _ = self.decode(codes)
|
||||||
|
@ -243,9 +245,6 @@ class DiscreteVAE(nn.Module):
|
||||||
# reconstruction loss
|
# reconstruction loss
|
||||||
recon_loss = self.loss_fn(img, out, reduction='none')
|
recon_loss = self.loss_fn(img, out, reduction='none')
|
||||||
|
|
||||||
# This is so we can debug the distribution of codes being learned.
|
|
||||||
self.log_codes(codes)
|
|
||||||
|
|
||||||
return recon_loss, commitment_loss, out
|
return recon_loss, commitment_loss, out
|
||||||
|
|
||||||
def log_codes(self, codes):
|
def log_codes(self, codes):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user