Use syncbatchnorm instead

This commit is contained in:
James Betker 2021-02-04 22:26:36 -07:00
parent bb79fafb89
commit 025a5867c4
3 changed files with 12 additions and 5 deletions

View File

@@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module):
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride) self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
elif coupler_mode == 'lambda': elif coupler_mode == 'lambda':
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
nn.GroupNorm(2, coupler_dim_in), nn.BatchNorm2d(coupler_dim_in),
nn.ReLU(), nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.GroupNorm(1, breadth), nn.BatchNorm2d(breadth),
nn.ReLU(), nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride)) Conv2d(breadth, breadth, 1, stride=self.stride))
else: else:

View File

@@ -246,4 +246,7 @@ class VQVAE(nn.Module):
@register_model @register_model
def register_vqvae(opt_net, opt): def register_vqvae(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {}) kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE(**kw) vq = VQVAE(**kw)
if distributed.is_initialized() and distributed.get_world_size() > 1:
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
return vq

View File

@@ -4,6 +4,7 @@ from time import time
import torch import torch
import torchvision import torchvision
from torch import nn from torch import nn
from torch.nn.parallel import distributed
from tqdm import tqdm from tqdm import tqdm
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \ from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
@@ -223,13 +224,16 @@ def convert_weights(weights_file):
@register_model @register_model
def register_vqvae3_hard_switch(opt_net, opt): def register_vqvae3_hard_switch(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {}) kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE3HardSwitch(**kw) vq = VQVAE3HardSwitch(**kw)
if distributed.is_initialized() and distributed.get_world_size() > 1:
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
return vq
def performance_test(): def performance_test():
cfg = { cfg = {
'mode': 'lambda', 'mode': 'lambda',
'breadth': 16, 'breadth': 8,
'hard_enabled': True, 'hard_enabled': True,
'dropout': 0.4 'dropout': 0.4
} }