forked from mrq/DL-Art-School
Fix lr quantizer decode
This commit is contained in:
parent
3f6ecfe0db
commit
7a3c4a4fc6
|
@ -195,7 +195,10 @@ class DiscreteVAE(nn.Module):
|
||||||
img_seq
|
img_seq
|
||||||
):
|
):
|
||||||
self.log_codes(img_seq)
|
self.log_codes(img_seq)
|
||||||
image_embeds = self.codebook.embed_code(img_seq)
|
if hasattr(self.codebook, 'embed_code'):
|
||||||
|
image_embeds = self.codebook.embed_code(img_seq)
|
||||||
|
else:
|
||||||
|
image_embeds = F.embedding(img_seq, self.codebook.embed.transpose(1,0))
|
||||||
b, n, d = image_embeds.shape
|
b, n, d = image_embeds.shape
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
@ -273,5 +276,6 @@ if __name__ == '__main__':
|
||||||
use_lr_quantizer=True)
|
use_lr_quantizer=True)
|
||||||
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
|
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
|
||||||
#v.eval()
|
#v.eval()
|
||||||
o=v(torch.randn(1,80,256))
|
r,l,o=v(torch.randn(1,80,256))
|
||||||
print(o[-1].shape)
|
v.decode(torch.randint(0,8192,(1,256)))
|
||||||
|
print(o.shape, l.shape)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user