diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py index d87d4909..f6c48db0 100644 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -185,10 +185,10 @@ class SwitchedConvHardRouting(nn.Module): self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride) elif coupler_mode == 'lambda': self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), - nn.BatchNorm2d(coupler_dim_in), + nn.GroupNorm(16, coupler_dim_in), nn.ReLU(), LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), - nn.BatchNorm2d(breadth), + nn.GroupNorm(16, breadth), nn.ReLU(), Conv2d(breadth, breadth, 1, stride=self.stride)) else: @@ -240,7 +240,7 @@ class SwitchedConvHardRouting(nn.Module): self.last_select = selector.detach().clone() self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1) - if False: + if True: # This is a custom CUDA implementation which should be faster and less memory intensive (once completed). return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride) else: diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py index 212e912c..77d9587b 100644 --- a/codes/models/vqvae/vqvae_3_hardswitch.py +++ b/codes/models/vqvae/vqvae_3_hardswitch.py @@ -214,7 +214,7 @@ def convert_weights(weights_file): from models.vqvae.vqvae_3 import VQVAE3 std_model = VQVAE3() std_model.load_state_dict(sd) - nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 8, ['quantize_conv_t', 'quantize_conv_b', + nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 16, ['quantize_conv_t', 'quantize_conv_b', 'enc_b.blocks.0', 'enc_t.blocks.0', 'conv.1', 'conv.3', 'initial_conv', 'final_conv']) torch.save(nsd, "converted.pth") @@ -229,9 +229,9 @@ def register_vqvae3_hard_switch(opt_net, opt): def performance_test(): cfg = { 'mode': 'lambda', - 'breadth': 8, - 'hard_enabled': False, - 'dropout': 0 + 'breadth': 16, + 'hard_enabled': True, + 'dropout': 0.4 } net = VQVAE3HardSwitch(cfg=cfg).to('cuda') loss = nn.L1Loss() @@ -250,5 +250,5 @@ def performance_test(): if __name__ == '__main__': #v = VQVAE3HardSwitch() #print(v(torch.randn(1,3,128,128))[0].shape) - convert_weights("../../../experiments/test_vqvae3.pth") - #performance_test() + #convert_weights("../../../experiments/vqvae_base.pth") + performance_test()