forked from mrq/DL-Art-School
Switched conv: add conversion function with allowlist
This commit is contained in:
parent cf9a6da889
commit 9fc3df3f5b

@@ -179,7 +179,9 @@ class SwitchedConvHardRouting(nn.Module):
                  include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
                  coupler_mode: str = 'standard',
                  coupler_dim_in: int = 0,
-                 hard_en=True): # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
+                 hard_en=True, # A test switch that, when used in 'emulation mode' (where all convs are calculated using torch functions) computes soft-attention instead of hard-attention.
+                 emulate_swconv=True, # When set, performs a nn.Conv2d operation for each breadth. When false, uses the native cuda implementation which computes all switches concurrently.
+                 ):
         super().__init__()
         self.in_channels = in_c
         self.out_channels = out_c
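
Note on the 'coupler' referenced in the comment above: the diff describes it as a linear 1x1 conv followed by a softmax and an interpolate. Below is a minimal sketch of that idea for orientation only; the CouplerSketch name, layer choices and interpolation mode are assumptions, not the repository's actual coupler implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CouplerSketch(nn.Module):
    # Hypothetical stand-in: maps any b x c x h x w latent to a b x breadth x H x W selector.
    def __init__(self, dim_in, breadth):
        super().__init__()
        self.reduce = nn.Conv2d(dim_in, breadth, kernel_size=1)  # linear 1x1 conv

    def forward(self, latent, out_hw):
        sel = self.reduce(latent)         # b x breadth x h x w logits
        sel = F.softmax(sel, dim=1)       # normalize across the switch dimension
        return F.interpolate(sel, size=out_hw, mode='nearest')  # match the conv's output size

# e.g. CouplerSketch(dim_in=32, breadth=8)(torch.randn(4, 32, 64, 64), (32, 32)) has shape (4, 8, 32, 32)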

@@ -291,18 +293,61 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
     return state_dict


+# Given a state_dict and the module that the sd belongs to, strips out the specified Conv2d modules and replaces them
+# with equivalent switched_conv modules.
+def convert_net_to_switched_conv(module, switch_breadth, allow_list, dropout_rate=0.4, coupler_mode='lambda'):
+    print("CONVERTING MODEL TO SWITCHED_CONV MODE")
+
+    # Next, convert the model itself:
+    full_paths = [n.split('.') for n in allow_list]
+    for modpath in full_paths:
+        mod = module
+        for sub in modpath[:-1]:
+            pmod = mod
+            mod = getattr(mod, sub)
+        old_conv = getattr(mod, modpath[-1])
+        new_conv = SwitchedConvHardRouting('.'.join(modpath), old_conv.in_channels, old_conv.out_channels, old_conv.kernel_size[0], switch_breadth, old_conv.stride[0], old_conv.bias,
+                                           include_coupler=True, dropout_rate=dropout_rate, coupler_mode=coupler_mode)
+        new_conv = new_conv.to(old_conv.weight.device)
+        assert old_conv.dilation == 1 or old_conv.dilation == (1,1) or old_conv.dilation is None
+        if isinstance(mod, nn.Sequential):
+            # If we use the standard logic (in the else case) here, it reorders the sequential.
+            # Instead, extract the OrderedDict from the current sequential, replace the Conv inside that dict, then replace the entire sequential to keep the order.
+            emods = mod._modules
+            emods[modpath[-1]] = new_conv
+            delattr(pmod, modpath[-2])
+            pmod.add_module(modpath[-2], nn.Sequential(emods))
+        else:
+            delattr(mod, modpath[-1])
+            mod.add_module(modpath[-1], new_conv)
+
+
+def convert_state_dict_to_switched_conv(sd_file, switch_breadth, allow_list):
+    save = torch.load(sd_file)
+    sd = save['state_dict']
+    converted = 0
+    for cname in allow_list:
+        for sn in sd.keys():
+            if cname in sn and sn.endswith('weight'):
+                sd[sn] = sd[sn].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
+                converted += 1
+    print(f"Converted {converted} parameters.")
+    torch.save(save, sd_file.replace('.pt', "_converted.pt"))
+
+
 def test_net():
     for j in tqdm(range(100)):
         base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
         mod_conv = SwitchedConvHardRouting(32, 64, 3, breadth=8, stride=2, bias=True, include_coupler=True, coupler_dim_in=32, dropout_rate=.2).to('cuda')
         mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
         mod_conv.load_state_dict(mod_sd, strict=False)
-        inp = torch.randn((128,32,128,128), device='cuda')
+        inp = torch.randn((128, 32, 128, 128), device='cuda')
         out1 = base_conv(inp)
         out2 = mod_conv(inp, None)
         compare = (out2+torch.rand_like(out2)*1e-6).detach()
         MSELoss()(out2, compare).backward()
         assert(torch.max(torch.abs(out1-out2)) < 1e-5)


 if __name__ == '__main__':
     test_net()
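
The two functions added in this commit retrofit an existing Conv2d-based network and its saved checkpoint. Below is a hedged usage sketch: the import path, the toy model, the 'body.0' path and the 'pretrained.pt' filename are assumptions for illustration; only the function names, signatures and behaviour come from the diff above.

import torch.nn as nn
# Import path is an assumption; adjust to wherever this module lives in the repository.
from switched_conv_hard_routing import convert_net_to_switched_conv, convert_state_dict_to_switched_conv

model = nn.Module()  # stand-in for a real network containing Conv2d layers
model.add_module('body', nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU()))

# Only modules whose dotted paths appear in the allowlist are swapped for SwitchedConvHardRouting;
# each path is split on '.' and walked with getattr().
allow_list = ['body.0']

convert_net_to_switched_conv(model, switch_breadth=8, allow_list=allow_list,
                             dropout_rate=0.4, coupler_mode='lambda')

# The companion function rewrites a saved checkpoint (expected to hold a 'state_dict' key) so that
# the allowlisted Conv2d weights match the switched layout, writing '<name>_converted.pt' on disk.
convert_state_dict_to_switched_conv('pretrained.pt', switch_breadth=8, allow_list=allow_list)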
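
As a shape check on convert_state_dict_to_switched_conv: the unsqueeze(2).repeat(1,1,switch_breadth,1,1) call inserts a switch dimension into each matching Conv2d weight and tiles the pretrained kernel across it, so every switch starts from the same parameters. A small standalone sketch (the 64x32x3x3 weight shape is an arbitrary example):

import torch

switch_breadth = 8
w = torch.randn(64, 32, 3, 3)  # a standard Conv2d weight: (out_c, in_c, kh, kw)
w_switched = w.unsqueeze(2).repeat(1, 1, switch_breadth, 1, 1)

print(w_switched.shape)                     # torch.Size([64, 32, 8, 3, 3])
assert torch.equal(w_switched[:, :, 0], w)  # each of the 8 switches holds the original kernel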