another fix

This commit is contained in:
James Betker 2021-11-24 00:11:21 -07:00
parent 7a3c4a4fc6
commit f1ed0588e3

View File

@ -151,7 +151,7 @@ class DiscreteVAE(nn.Module):
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
if use_lr_quantizer: if use_lr_quantizer:
self.codebook = VectorQuantize(dim=codebook_dim, n_embed=num_tokens, **lr_quantizer_args) self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
else: else:
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
@ -273,7 +273,7 @@ if __name__ == '__main__':
#print(o.shape) #print(o.shape)
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048, v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False, hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
use_lr_quantizer=True) use_lr_quantizer=True, lr_quantizer_args={'kmeans_init': True})
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) #v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
#v.eval() #v.eval()
r,l,o=v(torch.randn(1,80,256)) r,l,o=v(torch.randn(1,80,256))