From 68bcab03ae37507a0f999ecc7a3045a009e3af38 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 22 Jun 2020 10:40:16 -0600
Subject: [PATCH] Add growth channel to switch_growths for flat networks

---
 .../archs/SwitchedResidualGenerator_arch.py  | 22 ++++++++++++++-----
 codes/models/networks.py                     |  3 ++-
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index c35612ce..1b43173c 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -3,6 +3,7 @@ from torch import nn
 from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
+from collections import OrderedDict
 from models.archs.arch_util import initialize_weights
 from switched_conv_util import save_attention_to_image
 
@@ -61,13 +62,24 @@ class HalvingProcessingBlock(nn.Module):
         return self.bnconv2(x)
 
 
+# Creates a nested series of convolutional blocks. Each block processes the input data in-place and adds
+# filter_growth filters. Return is (nn.Sequential, ending_filters)
+def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
+    convs = []
+    current_filters = filters_init
+    for i in range(num_convs):
+        convs.append(ConvBnLelu(current_filters, current_filters + filter_growth))
+        current_filters += filter_growth
+    return nn.Sequential(*convs), current_filters
+
+
 class SwitchComputer(nn.Module):
-    def __init__(self, channels_in, filters, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
+    def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
         super(SwitchComputer, self).__init__()
         self.filter_conv = ConvBnLelu(channels_in, filters)
         self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
         final_filters = filters * 2 ** reduction_blocks
-        self.processing_blocks = nn.ModuleList([ConvBnLelu(final_filters, final_filters) for i in range(processing_blocks)])
+        self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
         proc_block_filters = max(final_filters // 2, transform_count)
         self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
         self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
@@ -104,13 +116,13 @@ class SwitchComputer(nn.Module):
 
 
 class ConfigurableSwitchedResidualGenerator(nn.Module):
-    def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+    def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                  trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1):
         super(ConfigurableSwitchedResidualGenerator, self).__init__()
         switches = []
-        for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
-            switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
+        for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
+            switches.append(SwitchComputer(3, filters, growth, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
         initialize_weights(switches, 1)
         # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
         initialize_weights([s.transforms for s in switches], .2 / len(switches))
diff --git a/codes/models/networks.py b/codes/models/networks.py
index c527b6a7..e126b0f3 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -63,7 +63,8 @@ def define_G(opt, net_key='network_G'):
                                                   final_temperature_step=opt_net['temperature_final_step'])
         netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
     elif which_model == "ConfigurableSwitchedResidualGenerator":
-        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'],
+        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
+                                                                      switch_reductions=opt_net['switch_reductions'],
                                                                       switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
                                                                       trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
                                                                       trans_filters_mid=opt_net['trans_filters_mid'],
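
Note (not part of the patch): the sketch below restates the new create_sequential_growing_processing_block in a self-contained script so its channel arithmetic can be checked in isolation. The ConvBnLelu stub is an assumption; it only mimics the real block's (channels_in, channels_out) signature, while the actual module is defined elsewhere in SwitchedResidualGenerator_arch.py.

    import torch
    from torch import nn

    # Stand-in for the repo's ConvBnLelu (conv + batchnorm + leaky ReLU).
    # Only the channel-count behavior matters for this demonstration.
    class ConvBnLelu(nn.Module):
        def __init__(self, channels_in, channels_out):
            super().__init__()
            self.conv = nn.Conv2d(channels_in, channels_out, 3, padding=1)
            self.bn = nn.BatchNorm2d(channels_out)
            self.lelu = nn.LeakyReLU(0.2)

        def forward(self, x):
            return self.lelu(self.bn(self.conv(x)))

    # Same logic as the function added by the patch: each block widens the
    # feature map by filter_growth channels, and the final channel count is
    # returned alongside the nn.Sequential.
    def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
        convs = []
        current_filters = filters_init
        for i in range(num_convs):
            convs.append(ConvBnLelu(current_filters, current_filters + filter_growth))
            current_filters += filter_growth
        return nn.Sequential(*convs), current_filters

    block, final_filters = create_sequential_growing_processing_block(64, 16, 3)
    assert final_filters == 64 + 3 * 16  # 112 channels after three growing blocks
    x = torch.randn(1, 64, 32, 32)
    assert block(x).shape == (1, 112, 32, 32)  # spatial size preserved, channels grown

This growth is why SwitchComputer now reassigns final_filters from the function's return value: downstream sizes such as proc_block_filters = max(final_filters // 2, transform_count) must be derived from the grown channel count, not from the flat filters * 2 ** reduction_blocks value the old ModuleList assumed.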