diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py index cfdef76d..ff6a950e 100644 --- a/codes/models/vqvae/vqvae_3_hardswitch.py +++ b/codes/models/vqvae/vqvae_3_hardswitch.py @@ -231,8 +231,15 @@ def register_vqvae3_hard_switch(opt_net, opt): def performance_test(): + # For breadth=32: + # Custom_cuda_naive: 28.9s + # Torch_native: 29.2s + # + # For breadth=8 + # Custom_cuda_naive: 18.4s + # Torch_native: 10s cfg = { - 'mode': 'lambda2', + 'mode': 'lambda', 'breadth': 8, 'hard_enabled': True, 'dropout': 0.4