diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 67221723..f7b124b0 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -195,7 +195,10 @@ class DiscreteVAE(nn.Module): img_seq ): self.log_codes(img_seq) - image_embeds = self.codebook.embed_code(img_seq) + if hasattr(self.codebook, 'embed_code'): + image_embeds = self.codebook.embed_code(img_seq) + else: + image_embeds = F.embedding(img_seq, self.codebook.embed.transpose(1,0)) b, n, d = image_embeds.shape kwargs = {} @@ -273,5 +276,6 @@ if __name__ == '__main__': use_lr_quantizer=True) #v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) #v.eval() - o=v(torch.randn(1,80,256)) - print(o[-1].shape) + r,l,o=v(torch.randn(1,80,256)) + v.decode(torch.randint(0,8192,(1,256))) + print(o.shape, l.shape)