From 662920bde33c846a831e5132b5736e916cdeab79 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 6 Dec 2021 09:21:43 -0700 Subject: [PATCH] Log codes when simply fetching codebook_indices --- codes/models/gpt_voice/lucidrains_dvae.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 888753c0..e60271b3 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -188,6 +188,7 @@ class DiscreteVAE(nn.Module): img = self.norm(images) logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) sampled, codes, _ = self.codebook(logits) + self.log_codes(codes) return codes def decode( @@ -236,6 +237,7 @@ class DiscreteVAE(nn.Module): out = sampled for d in self.decoder: out = d(out) + self.log_codes(codes) else: # This is non-differentiable, but gives a better idea of how the network is actually performing. out, _ = self.decode(codes) @@ -243,9 +245,6 @@ class DiscreteVAE(nn.Module): # reconstruction loss recon_loss = self.loss_fn(img, out, reduction='none') - # This is so we can debug the distribution of codes being learned. - self.log_codes(codes) - return recon_loss, commitment_loss, out def log_codes(self, codes):