From 7a3c4a4fc6cf7693e8c5792c13a4e1ad32c140eb Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 24 Nov 2021 00:01:26 -0700 Subject: [PATCH] Fix lr quantizer decode --- codes/models/gpt_voice/lucidrains_dvae.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 67221723..f7b124b0 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -195,7 +195,10 @@ class DiscreteVAE(nn.Module): img_seq ): self.log_codes(img_seq) - image_embeds = self.codebook.embed_code(img_seq) + if hasattr(self.codebook, 'embed_code'): + image_embeds = self.codebook.embed_code(img_seq) + else: + image_embeds = F.embedding(img_seq, self.codebook.embed.transpose(1,0)) b, n, d = image_embeds.shape kwargs = {} @@ -273,5 +276,6 @@ if __name__ == '__main__': use_lr_quantizer=True) #v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) #v.eval() - o=v(torch.randn(1,80,256)) - print(o[-1].shape) + r,l,o=v(torch.randn(1,80,256)) + v.decode(torch.randint(0,8192,(1,256))) + print(o.shape, l.shape)