diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py
index cb150095..54ac8168 100644
--- a/codes/models/switched_conv/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv/switched_conv_hard_routing.py
@@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module):
             self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
         elif coupler_mode == 'lambda':
             self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
-                                         nn.GroupNorm(2, coupler_dim_in),
+                                         nn.BatchNorm2d(coupler_dim_in),
                                          nn.ReLU(),
                                          LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
-                                         nn.GroupNorm(1, breadth),
+                                         nn.BatchNorm2d(breadth),
                                          nn.ReLU(),
                                          Conv2d(breadth, breadth, 1, stride=self.stride))
         else:
diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py
index d1418d1d..f2b98ac4 100644
--- a/codes/models/vqvae/vqvae.py
+++ b/codes/models/vqvae/vqvae.py
@@ -246,4 +246,8 @@ class VQVAE(nn.Module):
 @register_model
 def register_vqvae(opt_net, opt):
     kw = opt_get(opt_net, ['kwargs'], {})
-    return VQVAE(**kw)
+    vq = VQVAE(**kw)
+    # With more than one distributed process, swap BatchNorm layers for SyncBatchNorm.
+    if distributed.is_initialized() and distributed.get_world_size() > 1:
+        vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
+    return vq
diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py
index 77d9587b..bc45d45f 100644
--- a/codes/models/vqvae/vqvae_3_hardswitch.py
+++ b/codes/models/vqvae/vqvae_3_hardswitch.py
@@ -4,6 +4,7 @@ from time import time
 import torch
 import torchvision
 from torch import nn
+from torch import distributed
 from tqdm import tqdm
 
 from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
@@ -223,13 +224,17 @@ def convert_weights(weights_file):
 @register_model
 def register_vqvae3_hard_switch(opt_net, opt):
     kw = opt_get(opt_net, ['kwargs'], {})
-    return VQVAE3HardSwitch(**kw)
+    vq = VQVAE3HardSwitch(**kw)
+    # With more than one distributed process, swap BatchNorm layers for SyncBatchNorm.
+    if distributed.is_initialized() and distributed.get_world_size() > 1:
+        vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
+    return vq
 
 
 def performance_test():
     cfg = {
         'mode': 'lambda',
-        'breadth': 16,
+        'breadth': 8,
         'hard_enabled': True,
         'dropout': 0.4
     }