diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py index 54ac8168..c6c62ed5 100644 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -191,6 +191,17 @@ class SwitchedConvHardRouting(nn.Module): nn.BatchNorm2d(breadth), nn.ReLU(), Conv2d(breadth, breadth, 1, stride=self.stride)) + elif coupler_mode == 'lambda2': + self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), + nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), + nn.ReLU(), + LambdaLayer(dim=coupler_dim_in, dim_out=coupler_dim_in, r=23, dim_k=16, heads=2, dim_u=1), + nn.GroupNorm(num_groups=2, num_channels=coupler_dim_in), + nn.ReLU(), + LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), + nn.GroupNorm(num_groups=2, num_channels=breadth), + nn.ReLU(), + Conv2d(breadth, breadth, 1, stride=self.stride)) else: self.coupler = None self.gate = HardRoutingGate(breadth, hard_en=hard_en) @@ -240,7 +251,7 @@ class SwitchedConvHardRouting(nn.Module): self.last_select = selector.detach().clone() self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1) - if True: + if False: # This is a custom CUDA implementation which should be faster and less memory intensive (once completed). return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride) else: diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py index bc45d45f..d1677a1b 100644 --- a/codes/models/vqvae/vqvae_3_hardswitch.py +++ b/codes/models/vqvae/vqvae_3_hardswitch.py @@ -232,7 +232,7 @@ def register_vqvae3_hard_switch(opt_net, opt): def performance_test(): cfg = { - 'mode': 'lambda', + 'mode': 'lambda2', 'breadth': 8, 'hard_enabled': True, 'dropout': 0.4