From 9fc3df3f5b0205aefa67c8e5152e21904578cccb Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 13 Mar 2021 10:44:56 -0700 Subject: [PATCH] Switched conv: add conversion function with allowlist --- .../switched_conv_hard_routing.py | 49 ++++++++++++++++++- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py index 4a48e650..5eecb313 100644 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -179,7 +179,9 @@ class SwitchedConvHardRouting(nn.Module): include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate. coupler_mode: str = 'standard', coupler_dim_in: int = 0, - hard_en=True): # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention. + hard_en=True, # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention. + emulate_swconv=True, # When set, performs a nn.Conv2d operation for each breadth. When false, uses the native cuda implementation which computes all switches concurrently. + ): super().__init__() self.in_channels = in_c self.out_channels = out_c @@ -291,18 +293,61 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_ return state_dict +# Given a state_dict and the module that that sd belongs to, strips out the specified Conv2d modules and replaces them +# with equivalent switched_conv modules. +def convert_net_to_switched_conv(module, switch_breadth, allow_list, dropout_rate=0.4, coupler_mode='lambda'): + print("CONVERTING MODEL TO SWITCHED_CONV MODE") + + # Next, convert the model itself: + full_paths = [n.split('.') for n in allow_list] + for modpath in full_paths: + mod = module + for sub in modpath[:-1]: + pmod = mod + mod = getattr(mod, sub) + old_conv = getattr(mod, modpath[-1]) + new_conv = SwitchedConvHardRouting('.'.join(modpath), old_conv.in_channels, old_conv.out_channels, old_conv.kernel_size[0], switch_breadth, old_conv.stride[0], old_conv.bias, + include_coupler=True, dropout_rate=dropout_rate, coupler_mode=coupler_mode) + new_conv = new_conv.to(old_conv.weight.device) + assert old_conv.dilation == 1 or old_conv.dilation == (1,1) or old_conv.dilation is None + if isinstance(mod, nn.Sequential): + # If we use the standard logic (in the else case) here, it reorders the sequential. + # Instead, extract the OrderedDict from the current sequential, replace the Conv inside that dict, then replace the entire sequential to keep the order. + emods = mod._modules + emods[modpath[-1]] = new_conv + delattr(pmod, modpath[-2]) + pmod.add_module(modpath[-2], nn.Sequential(emods)) + else: + delattr(mod, modpath[-1]) + mod.add_module(modpath[-1], new_conv) + + +def convert_state_dict_to_switched_conv(sd_file, switch_breadth, allow_list): + save = torch.load(sd_file) + sd = save['state_dict'] + converted = 0 + for cname in allow_list: + for sn in sd.keys(): + if cname in sn and sn.endswith('weight'): + sd[sn] = sd[sn].unsqueeze(2).repeat(1,1,switch_breadth,1,1) + converted += 1 + print(f"Converted {converted} parameters.") + torch.save(save, sd_file.replace('.pt', "_converted.pt")) + + def test_net(): for j in tqdm(range(100)): base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda') mod_conv = SwitchedConvHardRouting(32, 64, 3, breadth=8, stride=2, bias=True, include_coupler=True, coupler_dim_in=32, dropout_rate=.2).to('cuda') mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8) mod_conv.load_state_dict(mod_sd, strict=False) - inp = torch.randn((128,32,128,128), device='cuda') + inp = torch.randn((128, 32, 128, 128), device='cuda') out1 = base_conv(inp) out2 = mod_conv(inp, None) compare = (out2+torch.rand_like(out2)*1e-6).detach() MSELoss()(out2, compare).backward() assert(torch.max(torch.abs(out1-out2)) < 1e-5) + if __name__ == '__main__': test_net() \ No newline at end of file