Use syncbatchnorm instead
This commit is contained in:
parent
bb79fafb89
commit
025a5867c4
|
@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module):
|
||||||
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
|
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
|
||||||
elif coupler_mode == 'lambda':
|
elif coupler_mode == 'lambda':
|
||||||
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
|
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
|
||||||
nn.GroupNorm(2, coupler_dim_in),
|
nn.BatchNorm2d(coupler_dim_in),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
|
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
|
||||||
nn.GroupNorm(1, breadth),
|
nn.BatchNorm2d(breadth),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
Conv2d(breadth, breadth, 1, stride=self.stride))
|
Conv2d(breadth, breadth, 1, stride=self.stride))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -246,4 +246,7 @@ class VQVAE(nn.Module):
|
||||||
@register_model
|
@register_model
|
||||||
def register_vqvae(opt_net, opt):
|
def register_vqvae(opt_net, opt):
|
||||||
kw = opt_get(opt_net, ['kwargs'], {})
|
kw = opt_get(opt_net, ['kwargs'], {})
|
||||||
return VQVAE(**kw)
|
vq = VQVAE(**kw)
|
||||||
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
|
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
|
||||||
|
return vq
|
||||||
|
|
|
@ -4,6 +4,7 @@ from time import time
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn.parallel import distributed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
||||||
|
@ -223,13 +224,16 @@ def convert_weights(weights_file):
|
||||||
@register_model
|
@register_model
|
||||||
def register_vqvae3_hard_switch(opt_net, opt):
|
def register_vqvae3_hard_switch(opt_net, opt):
|
||||||
kw = opt_get(opt_net, ['kwargs'], {})
|
kw = opt_get(opt_net, ['kwargs'], {})
|
||||||
return VQVAE3HardSwitch(**kw)
|
vq = VQVAE3HardSwitch(**kw)
|
||||||
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
|
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
|
||||||
|
return vq
|
||||||
|
|
||||||
|
|
||||||
def performance_test():
|
def performance_test():
|
||||||
cfg = {
|
cfg = {
|
||||||
'mode': 'lambda',
|
'mode': 'lambda',
|
||||||
'breadth': 16,
|
'breadth': 8,
|
||||||
'hard_enabled': True,
|
'hard_enabled': True,
|
||||||
'dropout': 0.4
|
'dropout': 0.4
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user