This commit is contained in:
James Betker 2021-11-24 00:19:47 -07:00
parent f1ed0588e3
commit 5956eb757c

View File

@ -198,7 +198,7 @@ class DiscreteVAE(nn.Module):
if hasattr(self.codebook, 'embed_code'):
image_embeds = self.codebook.embed_code(img_seq)
else:
image_embeds = F.embedding(img_seq, self.codebook.embed.transpose(1,0))
image_embeds = F.embedding(img_seq, self.codebook.codebook)
b, n, d = image_embeds.shape
kwargs = {}
@ -273,7 +273,7 @@ if __name__ == '__main__':
#print(o.shape)
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
use_lr_quantizer=True, lr_quantizer_args={'kmeans_init': True})
use_lr_quantizer=True)
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
#v.eval()
r,l,o=v(torch.randn(1,80,256))