Back to groupnorm

This commit is contained in:
James Betker 2021-02-05 08:42:11 -07:00
parent 336f807c8e
commit 6dec1f5968
2 changed files with 4 additions and 4 deletions

View File

@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module):
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride) self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
elif coupler_mode == 'lambda': elif coupler_mode == 'lambda':
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
nn.BatchNorm2d(coupler_dim_in), nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
nn.ReLU(), nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.BatchNorm2d(breadth), nn.GroupNorm(num_groups=1, num_channels=breadth),
nn.ReLU(), nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride)) Conv2d(breadth, breadth, 1, stride=self.stride))
elif coupler_mode == 'lambda2': elif coupler_mode == 'lambda2':
@ -199,7 +199,7 @@ class SwitchedConvHardRouting(nn.Module):
nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in),
nn.ReLU(), nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.GroupNorm(num_groups=2, num_channels=breadth), nn.GroupNorm(num_groups=1, num_channels=breadth),
nn.ReLU(), nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride)) Conv2d(breadth, breadth, 1, stride=self.stride))
else: else:

View File

@ -3,8 +3,8 @@ from time import time
import torch import torch
import torchvision import torchvision
import torch.distributed as distributed
from torch import nn from torch import nn
from torch.nn.parallel import distributed
from tqdm import tqdm from tqdm import tqdm
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \ from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \