Back to groupnorm

This commit is contained in:
James Betker 2021-02-05 08:42:11 -07:00
parent 336f807c8e
commit 6dec1f5968
2 changed files with 4 additions and 4 deletions

View File

@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module):
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
elif coupler_mode == 'lambda':
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
nn.BatchNorm2d(coupler_dim_in),
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.BatchNorm2d(breadth),
nn.GroupNorm(num_groups=1, num_channels=breadth),
nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride))
elif coupler_mode == 'lambda2':
@ -199,7 +199,7 @@ class SwitchedConvHardRouting(nn.Module):
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.GroupNorm(num_groups=2, num_channels=breadth),
nn.GroupNorm(num_groups=1, num_channels=breadth),
nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride))
else:

View File

@ -3,8 +3,8 @@ from time import time
import torch
import torchvision
import torch.distributed as distributed
from torch import nn
from torch.nn.parallel import distributed
from tqdm import tqdm
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \