diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py index c6c62ed5..bb19c8e9 100644 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module): self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride) elif coupler_mode == 'lambda': self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), - nn.BatchNorm2d(coupler_dim_in), + nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), nn.ReLU(), LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), - nn.BatchNorm2d(breadth), + nn.GroupNorm(num_groups=1, num_channels=breadth), nn.ReLU(), Conv2d(breadth, breadth, 1, stride=self.stride)) elif coupler_mode == 'lambda2': @@ -199,7 +199,7 @@ class SwitchedConvHardRouting(nn.Module): nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), nn.ReLU(), LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), - nn.GroupNorm(num_groups=2, num_channels=breadth), + nn.GroupNorm(num_groups=1, num_channels=breadth), nn.ReLU(), Conv2d(breadth, breadth, 1, stride=self.stride)) else: diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py index d1677a1b..cfdef76d 100644 --- a/codes/models/vqvae/vqvae_3_hardswitch.py +++ b/codes/models/vqvae/vqvae_3_hardswitch.py @@ -3,8 +3,8 @@ from time import time import torch import torchvision +import torch.distributed as distributed from torch import nn -from torch.nn.parallel import distributed from tqdm import tqdm from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \