diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 9a472cae..fb5296a4 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -43,12 +43,13 @@ class ResidualBranch(nn.Module): return x * self.scale + self.bias -# VGG-style layer with Conv->BN->Activation->Conv(stride2)->BN->Activation +# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation +# Doubles the input filter count. class HalvingProcessingBlock(nn.Module): def __init__(self, filters): super(HalvingProcessingBlock, self).__init__() - self.bnconv1 = ConvBnLelu(filters, filters) - self.bnconv2 = ConvBnLelu(filters, filters * 2, stride=2) + self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2) + self.bnconv2 = ConvBnLelu(filters * 2, filters * 2) def forward(self, x): x = self.bnconv1(x) @@ -56,11 +57,12 @@ class HalvingProcessingBlock(nn.Module): class SwitchComputer(nn.Module): - def __init__(self, channels_in, filters, transform_block, transform_count, reductions, init_temp=20): + def __init__(self, channels_in, filters, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20): super(SwitchComputer, self).__init__() self.filter_conv = ConvBnLelu(channels_in, filters) - self.blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reductions)]) - final_filters = filters * 2 ** reductions + self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)]) + final_filters = filters * 2 ** reduction_blocks + self.processing_blocks = nn.ModuleList([ConvBnLelu(final_filters, final_filters) for i in range(processing_blocks)]) proc_block_filters = max(final_filters // 2, transform_count) self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters) self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0) @@ -77,7 +79,9 @@ class SwitchComputer(nn.Module): xformed.append(torch.zeros_like(xformed[0])) multiplexer = self.filter_conv(x) - for block in self.blocks: + for block in self.reduction_blocks: + multiplexer = block.forward(multiplexer) + for block in self.processing_blocks: multiplexer = block.forward(multiplexer) multiplexer = self.proc_switch_conv(multiplexer) multiplexer = self.final_switch_conv.forward(multiplexer) @@ -92,10 +96,10 @@ class SwitchComputer(nn.Module): class SwitchedResidualGenerator(nn.Module): def __init__(self, switch_filters, initial_temp=20, final_temperature_step=50000): super(SwitchedResidualGenerator, self).__init__() - self.switch1 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=7, depth=3), 4, 4, initial_temp) - self.switch2 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=5, depth=3), 8, 3, initial_temp) - self.switch3 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=3), 16, 2, initial_temp) - self.switch4 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=2), 32, 1, initial_temp) + self.switch1 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=7, depth=3), 4, 4, 0, initial_temp) + self.switch2 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=5, depth=3), 8, 3, 0, initial_temp) + self.switch3 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=3), 16, 2, 1, initial_temp) + self.switch4 = SwitchComputer(3, switch_filters * 2, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=2), 32, 1, 2, initial_temp) initialize_weights([self.switch1, self.switch2, self.switch3, self.switch4], 1) # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. initialize_weights([self.switch1.transforms, self.switch2.transforms, self.switch3.transforms, self.switch4.transforms], .05) @@ -155,11 +159,11 @@ class SwitchedResidualGenerator(nn.Module): class ConfigurableSwitchedResidualGenerator(nn.Module): - def __init__(self, switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers, initial_temp=20, final_temperature_step=50000): + def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, initial_temp=20, final_temperature_step=50000): super(ConfigurableSwitchedResidualGenerator, self).__init__() switches = [] - for filters, depth, trans_count, kernel, layers in zip(switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers): - switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, depth, initial_temp)) + for filters, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers): + switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp)) initialize_weights(switches, 1) # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. initialize_weights([s.transforms for s in switches], .05) diff --git a/codes/models/networks.py b/codes/models/networks.py index 6c70fcef..807f1478 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -73,7 +73,8 @@ def define_G(opt, net_key='network_G'): netG = SwitchedGen_arch.SwitchedResidualGenerator(switch_filters=opt_net['nf'], initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step']) elif which_model == "ConfigurableSwitchedResidualGenerator": - netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_depths=opt_net['switch_depths'], trans_counts=opt_net['trans_counts'], + netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'], + switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])