diff --git a/codes/models/vqvae/vqvae_3.py b/codes/models/vqvae/vqvae_3.py index 3c040775..dc3f92cb 100644 --- a/codes/models/vqvae/vqvae_3.py +++ b/codes/models/vqvae/vqvae_3.py @@ -170,9 +170,9 @@ class VQVAE3(nn.Module): @register_model -def register_vqvae_normalized(opt_net, opt): +def register_vqvae3(opt_net, opt): kw = opt_get(opt_net, ['kwargs'], {}) - return VQVAE(**kw) + return VQVAE3(**kw) if __name__ == '__main__':