diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py index 9374b432..1222a5ba 100644 --- a/codes/models/vqvae/vqvae_3_hardswitch.py +++ b/codes/models/vqvae/vqvae_3_hardswitch.py @@ -1,4 +1,5 @@ import torch +import torchvision from torch import nn from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \