forked from mrq/DL-Art-School
q fix
This commit is contained in:
parent
d9747fe623
commit
3f6ecfe0db
|
@ -151,7 +151,7 @@ class DiscreteVAE(nn.Module):
|
|||
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
||||
|
||||
if use_lr_quantizer:
|
||||
self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
|
||||
self.codebook = VectorQuantize(dim=codebook_dim, n_embed=num_tokens, **lr_quantizer_args)
|
||||
else:
|
||||
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
||||
|
||||
|
@ -270,7 +270,7 @@ if __name__ == '__main__':
|
|||
#print(o.shape)
|
||||
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
|
||||
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
|
||||
encoder_choke=True, choke_dim=256)
|
||||
use_lr_quantizer=True)
|
||||
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
|
||||
#v.eval()
|
||||
o=v(torch.randn(1,80,256))
|
||||
|
|
Loading…
Reference in New Issue
Block a user