From e7be4bdff33f02f087b76a094f0e7b1c2585d13c Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 5 Feb 2021 08:43:07 -0700 Subject: [PATCH] Revert --- codes/models/switched_conv/switched_conv_hard_routing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py index bb19c8e9..689504da 100644 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module): self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride) elif coupler_mode == 'lambda': self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), - nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), + nn.BatchNorm2d(coupler_dim_in), nn.ReLU(), LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), - nn.GroupNorm(num_groups=1, num_channels=breadth), + nn.BatchNorm2d(breadth), nn.ReLU(), Conv2d(breadth, breadth, 1, stride=self.stride)) elif coupler_mode == 'lambda2':