Switched conv: add conversion function with allowlist
This commit is contained in:
parent
cf9a6da889
commit
9fc3df3f5b
|
@ -179,7 +179,9 @@ class SwitchedConvHardRouting(nn.Module):
|
|||
include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
|
||||
coupler_mode: str = 'standard',
|
||||
coupler_dim_in: int = 0,
|
||||
hard_en=True): # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
|
||||
hard_en=True, # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
|
||||
emulate_swconv=True, # When set, performs a nn.Conv2d operation for each breadth. When false, uses the native cuda implementation which computes all switches concurrently.
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_c
|
||||
self.out_channels = out_c
|
||||
|
@ -291,18 +293,61 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
|
|||
return state_dict
|
||||
|
||||
|
||||
# Given a state_dict and the module that that sd belongs to, strips out the specified Conv2d modules and replaces them
|
||||
# with equivalent switched_conv modules.
|
||||
def convert_net_to_switched_conv(module, switch_breadth, allow_list, dropout_rate=0.4, coupler_mode='lambda'):
|
||||
print("CONVERTING MODEL TO SWITCHED_CONV MODE")
|
||||
|
||||
# Next, convert the model itself:
|
||||
full_paths = [n.split('.') for n in allow_list]
|
||||
for modpath in full_paths:
|
||||
mod = module
|
||||
for sub in modpath[:-1]:
|
||||
pmod = mod
|
||||
mod = getattr(mod, sub)
|
||||
old_conv = getattr(mod, modpath[-1])
|
||||
new_conv = SwitchedConvHardRouting('.'.join(modpath), old_conv.in_channels, old_conv.out_channels, old_conv.kernel_size[0], switch_breadth, old_conv.stride[0], old_conv.bias,
|
||||
include_coupler=True, dropout_rate=dropout_rate, coupler_mode=coupler_mode)
|
||||
new_conv = new_conv.to(old_conv.weight.device)
|
||||
assert old_conv.dilation == 1 or old_conv.dilation == (1,1) or old_conv.dilation is None
|
||||
if isinstance(mod, nn.Sequential):
|
||||
# If we use the standard logic (in the else case) here, it reorders the sequential.
|
||||
# Instead, extract the OrderedDict from the current sequential, replace the Conv inside that dict, then replace the entire sequential to keep the order.
|
||||
emods = mod._modules
|
||||
emods[modpath[-1]] = new_conv
|
||||
delattr(pmod, modpath[-2])
|
||||
pmod.add_module(modpath[-2], nn.Sequential(emods))
|
||||
else:
|
||||
delattr(mod, modpath[-1])
|
||||
mod.add_module(modpath[-1], new_conv)
|
||||
|
||||
|
||||
def convert_state_dict_to_switched_conv(sd_file, switch_breadth, allow_list):
|
||||
save = torch.load(sd_file)
|
||||
sd = save['state_dict']
|
||||
converted = 0
|
||||
for cname in allow_list:
|
||||
for sn in sd.keys():
|
||||
if cname in sn and sn.endswith('weight'):
|
||||
sd[sn] = sd[sn].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
|
||||
converted += 1
|
||||
print(f"Converted {converted} parameters.")
|
||||
torch.save(save, sd_file.replace('.pt', "_converted.pt"))
|
||||
|
||||
|
||||
def test_net():
|
||||
for j in tqdm(range(100)):
|
||||
base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
|
||||
mod_conv = SwitchedConvHardRouting(32, 64, 3, breadth=8, stride=2, bias=True, include_coupler=True, coupler_dim_in=32, dropout_rate=.2).to('cuda')
|
||||
mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
|
||||
mod_conv.load_state_dict(mod_sd, strict=False)
|
||||
inp = torch.randn((128,32,128,128), device='cuda')
|
||||
inp = torch.randn((128, 32, 128, 128), device='cuda')
|
||||
out1 = base_conv(inp)
|
||||
out2 = mod_conv(inp, None)
|
||||
compare = (out2+torch.rand_like(out2)*1e-6).detach()
|
||||
MSELoss()(out2, compare).backward()
|
||||
assert(torch.max(torch.abs(out1-out2)) < 1e-5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_net()
|
Loading…
Reference in New Issue
Block a user