Switched conv: add conversion function with allowlist

This commit is contained in:
James Betker 2021-03-13 10:44:56 -07:00
parent cf9a6da889
commit 9fc3df3f5b

View File

@ -179,7 +179,9 @@ class SwitchedConvHardRouting(nn.Module):
include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
coupler_mode: str = 'standard',
coupler_dim_in: int = 0,
hard_en=True): # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
hard_en=True, # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
emulate_swconv=True, # When set, performs a nn.Conv2d operation for each breadth. When false, uses the native cuda implementation which computes all switches concurrently.
):
super().__init__()
self.in_channels = in_c
self.out_channels = out_c
@ -291,6 +293,48 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
return state_dict
# Given a state_dict and the module that that sd belongs to, strips out the specified Conv2d modules and replaces them
# with equivalent switched_conv modules.
def convert_net_to_switched_conv(module, switch_breadth, allow_list, dropout_rate=0.4, coupler_mode='lambda'):
print("CONVERTING MODEL TO SWITCHED_CONV MODE")
# Next, convert the model itself:
full_paths = [n.split('.') for n in allow_list]
for modpath in full_paths:
mod = module
for sub in modpath[:-1]:
pmod = mod
mod = getattr(mod, sub)
old_conv = getattr(mod, modpath[-1])
new_conv = SwitchedConvHardRouting('.'.join(modpath), old_conv.in_channels, old_conv.out_channels, old_conv.kernel_size[0], switch_breadth, old_conv.stride[0], old_conv.bias,
include_coupler=True, dropout_rate=dropout_rate, coupler_mode=coupler_mode)
new_conv = new_conv.to(old_conv.weight.device)
assert old_conv.dilation == 1 or old_conv.dilation == (1,1) or old_conv.dilation is None
if isinstance(mod, nn.Sequential):
# If we use the standard logic (in the else case) here, it reorders the sequential.
# Instead, extract the OrderedDict from the current sequential, replace the Conv inside that dict, then replace the entire sequential to keep the order.
emods = mod._modules
emods[modpath[-1]] = new_conv
delattr(pmod, modpath[-2])
pmod.add_module(modpath[-2], nn.Sequential(emods))
else:
delattr(mod, modpath[-1])
mod.add_module(modpath[-1], new_conv)
def convert_state_dict_to_switched_conv(sd_file, switch_breadth, allow_list):
save = torch.load(sd_file)
sd = save['state_dict']
converted = 0
for cname in allow_list:
for sn in sd.keys():
if cname in sn and sn.endswith('weight'):
sd[sn] = sd[sn].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
converted += 1
print(f"Converted {converted} parameters.")
torch.save(save, sd_file.replace('.pt', "_converted.pt"))
def test_net():
for j in tqdm(range(100)):
base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
@ -304,5 +348,6 @@ def test_net():
MSELoss()(out2, compare).backward()
assert(torch.max(torch.abs(out1-out2)) < 1e-5)
if __name__ == '__main__':
test_net()