Convert lambda coupler to use groupnorm instead of batchnorm

This commit is contained in:
James Betker 2021-02-04 21:59:44 -07:00
parent 7070142805
commit 43da1f9c4b
2 changed files with 9 additions and 9 deletions

View File

@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module):
self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
elif coupler_mode == 'lambda':
self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
nn.BatchNorm2d(coupler_dim_in),
nn.GroupNorm(16, coupler_dim_in),
nn.ReLU(),
LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
nn.BatchNorm2d(breadth),
nn.GroupNorm(16, breadth),
nn.ReLU(),
Conv2d(breadth, breadth, 1, stride=self.stride))
else:
@ -240,7 +240,7 @@ class SwitchedConvHardRouting(nn.Module):
self.last_select = selector.detach().clone()
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
if False:
if True:
# This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
else:

View File

@ -214,7 +214,7 @@ def convert_weights(weights_file):
from models.vqvae.vqvae_3 import VQVAE3
std_model = VQVAE3()
std_model.load_state_dict(sd)
nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 8, ['quantize_conv_t', 'quantize_conv_b',
nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 16, ['quantize_conv_t', 'quantize_conv_b',
'enc_b.blocks.0', 'enc_t.blocks.0',
'conv.1', 'conv.3', 'initial_conv', 'final_conv'])
torch.save(nsd, "converted.pth")
@ -229,9 +229,9 @@ def register_vqvae3_hard_switch(opt_net, opt):
def performance_test():
cfg = {
'mode': 'lambda',
'breadth': 8,
'hard_enabled': False,
'dropout': 0
'breadth': 16,
'hard_enabled': True,
'dropout': 0.4
}
net = VQVAE3HardSwitch(cfg=cfg).to('cuda')
loss = nn.L1Loss()
@ -250,5 +250,5 @@ def performance_test():
if __name__ == '__main__':
#v = VQVAE3HardSwitch()
#print(v(torch.randn(1,3,128,128))[0].shape)
convert_weights("../../../experiments/test_vqvae3.pth")
#performance_test()
#convert_weights("../../../experiments/vqvae_base.pth")
performance_test()