import functools
import math
from collections import OrderedDict
import torch
import torch.nn as nn
from lambda_networks import LambdaLayer
from torch.nn import init, Conv2d
import torch.nn.functional as F


class SwitchedConv(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 switch_breadth: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 # A 'coupler' is a latent converter that turns any BxCxHxW tensor into a
                 # compatible SwitchedConv selector by applying a 1x1 conv, a softmax and an
                 # interpolation to the convolution's output size.
                 include_coupler: bool = False,
                 coupler_mode: str = 'standard',
                 coupler_dim_in: int = 0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.padding_mode = padding_mode
        self.groups = groups

        if include_coupler:
            if coupler_mode == 'standard':
                self.coupler = Conv2d(coupler_dim_in, switch_breadth, kernel_size=1)
            elif coupler_mode == 'lambda':
                self.coupler = LambdaLayer(dim=coupler_dim_in, dim_out=switch_breadth, r=23, dim_k=16, heads=2, dim_u=1)
        else:
            self.coupler = None

        # One full Conv2d kernel per switch position.
        self.weights = nn.ParameterList([nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) for _ in range(switch_breadth)])
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        for w in self.weights:
            init.kaiming_uniform_(w, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weights[0])
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, inp, selector=None):
        if self.coupler:
            # A coupler can convert any input into a selector, so selector=None is allowed.
            if selector is None:
                selector = inp
            # Normalize across the switch dimension, then resize the selector to the
            # convolution's output resolution if necessary.
            selector = F.softmax(self.coupler(selector), dim=1)
            out_shape = [s // self.stride for s in inp.shape[2:]]
            if selector.shape[2] != out_shape[0] or selector.shape[3] != out_shape[1]:
                selector = F.interpolate(selector, size=out_shape, mode="nearest")
        assert selector is not None
        # Run every switch's convolution, weight each result pixel-wise by its
        # selector channel, and sum over the switch dimension.
        conv_results = []
        for i, w in enumerate(self.weights):
            conv_results.append(F.conv2d(inp, w, self.bias, self.stride, self.padding, self.dilation, self.groups) * selector[:, i].unsqueeze(1))
        return torch.stack(conv_results, dim=-1).sum(dim=-1)
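

# A minimal usage sketch of the coupler path (the shapes below are illustrative,
# not from any particular network): with include_coupler=True, the selector
# argument may be omitted and the convolution input itself drives the switch.
def coupler_demo():
    sc = SwitchedConv(32, 64, 3, switch_breadth=4, padding=1,
                      include_coupler=True, coupler_dim_in=32)
    out = sc(torch.randn(1, 32, 16, 16))  # selector is derived from the input
    assert out.shape == (1, 64, 16, 16)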


# Given a state_dict and the module that state_dict belongs to, strips out all Conv2d.weight
# parameters and replaces them with the equivalent SwitchedConv.weights parameters. Does not
# create coupler params.
def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=None):
    ignore_list = ignore_list or []  # avoid a mutable default argument
    state_dict = module.state_dict()
    for name, m in module.named_modules():
        # Skip modules whose qualified name matches an entry in ignore_list.
        ignored = False
        for smod in ignore_list:
            if smod in name:
                ignored = True
                break
        if ignored:
            continue
        if isinstance(m, nn.Conv2d):
            if name == '':
                # module itself is the Conv2d being converted
                basename = 'weight'
                modname = 'weights'
            else:
                basename = f'{name}.weight'
                modname = f'{name}.weights'
            # Replicate the original kernel once per switch position.
            cnv_weights = state_dict[basename]
            del state_dict[basename]
            for j in range(switch_breadth):
                state_dict[f'{modname}.{j}'] = cnv_weights.clone()
    return state_dict
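

# A minimal sketch of the key rewrite above (the module is illustrative): a bare
# Conv2d's 'weight' becomes 'weights.0' .. 'weights.N-1', each a copy of the original
# kernel, while 'bias' is left untouched, so a freshly converted SwitchedConv starts
# out numerically equivalent to the source convolution.
def conversion_demo():
    conv = Conv2d(4, 8, 3)
    sd = convert_conv_net_state_dict_to_switched_conv(conv, switch_breadth=2)
    assert set(sd.keys()) == {'weights.0', 'weights.1', 'bias'}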


def test_net():
    base_conv = Conv2d(32, 64, 3, stride=2, padding=1, bias=True).to('cuda')
    mod_conv = SwitchedConv(32, 64, 3, switch_breadth=8, stride=2, padding=1, bias=True, include_coupler=True, coupler_dim_in=128).to('cuda')
    mod_sd = convert_conv_net_state_dict_to_switched_conv(base_conv, 8)
    # strict=False because the converted state_dict contains no coupler parameters.
    mod_conv.load_state_dict(mod_sd, strict=False)
    inp = torch.randn((8, 32, 128, 128), device='cuda')
    sel = torch.randn((8, 128, 32, 32), device='cuda')
    out1 = base_conv(inp)
    out2 = mod_conv(inp, sel)
    # Every switch holds an identical copy of base_conv's kernel and the selector
    # softmax sums to 1, so the switched output must match the plain convolution.
    assert torch.max(torch.abs(out1 - out2)) < 1e-6


def perform_conversion():
    sd = torch.load("../experiments/rrdb_imgset_226500_generator.pth")
    load_net_clean = OrderedDict()  # remove the unnecessary 'module.' prefix
    for k, v in sd.items():
        if k.startswith('module.'):
            load_net_clean[k.replace('module.', '')] = v
        else:
            load_net_clean[k] = v
    sd = load_net_clean

    import models.RRDBNet_arch as rrdb
    block = functools.partial(rrdb.RRDBWithBypass)
    mod = rrdb.RRDBNet(in_channels=3, out_channels=3,
                       mid_channels=64, num_blocks=23, body_block=block, scale=2, initial_stride=2)
    mod.load_state_dict(sd)
    converted = convert_conv_net_state_dict_to_switched_conv(mod, 8, ['body.', 'conv_first', 'resnet_encoder'])
    torch.save(converted, "../experiments/rrdb_imgset_226500_generator_converted.pth")


if __name__ == '__main__':
    perform_conversion()