diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 1d21a3bc..67221723 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -151,7 +151,7 @@ class DiscreteVAE(nn.Module): self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss if use_lr_quantizer: - self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args) + self.codebook = VectorQuantize(dim=codebook_dim, n_embed=num_tokens, **lr_quantizer_args) else: self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) @@ -270,7 +270,7 @@ if __name__ == '__main__': #print(o.shape) v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048, hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False, - encoder_choke=True, choke_dim=256) + use_lr_quantizer=True) #v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) #v.eval() o=v(torch.randn(1,80,256))