From 98057b651686f34c8de07d7d42826c74d303eb0b Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 14 Aug 2021 23:43:01 -0600 Subject: [PATCH] Make lrdvae use quantized mode in eval() --- codes/models/gpt_voice/lucidrains_dvae.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 1dda5fdd..c59483df 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -158,6 +158,9 @@ class DiscreteVAE(nn.Module): images = self.decoder(image_embeds) return images + # Note: This module is not meant to be run in forward() except while training. It has special logic which performs + # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially + # more lossy (but useful for determining network performance). def forward( self, img @@ -166,7 +169,12 @@ class DiscreteVAE(nn.Module): logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) sampled, commitment_loss, codes = self.codebook(logits) sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1)) - out = self.decoder(sampled) + + if self.training: + out = self.decoder(sampled) + else: + # This is non-differentiable, but gives a better idea of how the network is actually performing. + out = self.decode(codes) # reconstruction loss recon_loss = self.loss_fn(img, out) @@ -195,5 +203,6 @@ if __name__ == '__main__': #o=v(torch.randn(1,3,256,256)) #print(o.shape) v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256) + v.eval() o=v(torch.randn(1,1,256)) print(o[-1].shape)