Add growth channel to switch_growths for flat networks

This commit is contained in:
James Betker 2020-06-22 10:40:16 -06:00
parent 3b81712c49
commit 68bcab03ae
2 changed files with 19 additions and 6 deletions

View File

@ -3,6 +3,7 @@ from torch import nn
from switched_conv import BareConvSwitch, compute_attention_specificity from switched_conv import BareConvSwitch, compute_attention_specificity
import torch.nn.functional as F import torch.nn.functional as F
import functools import functools
from collections import OrderedDict
from models.archs.arch_util import initialize_weights from models.archs.arch_util import initialize_weights
from switched_conv_util import save_attention_to_image from switched_conv_util import save_attention_to_image
@ -61,13 +62,24 @@ class HalvingProcessingBlock(nn.Module):
return self.bnconv2(x) return self.bnconv2(x)
# Builds a chain of convolutional blocks where each block widens the channel
# count by filter_growth. Returns (nn.Sequential of the blocks, ending_filters).
def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
    # Channel widths at each stage boundary: num_convs + 1 entries.
    widths = [filters_init + i * filter_growth for i in range(num_convs + 1)]
    # Consecutive width pairs define each ConvBnLelu's (in, out) channels.
    blocks = [ConvBnLelu(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])]
    return nn.Sequential(*blocks), widths[-1]
class SwitchComputer(nn.Module): class SwitchComputer(nn.Module):
def __init__(self, channels_in, filters, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20): def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
super(SwitchComputer, self).__init__() super(SwitchComputer, self).__init__()
self.filter_conv = ConvBnLelu(channels_in, filters) self.filter_conv = ConvBnLelu(channels_in, filters)
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)]) self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
final_filters = filters * 2 ** reduction_blocks final_filters = filters * 2 ** reduction_blocks
self.processing_blocks = nn.ModuleList([ConvBnLelu(final_filters, final_filters) for i in range(processing_blocks)]) self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
proc_block_filters = max(final_filters // 2, transform_count) proc_block_filters = max(final_filters // 2, transform_count)
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters) self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0) self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
@ -104,13 +116,13 @@ class SwitchComputer(nn.Module):
class ConfigurableSwitchedResidualGenerator(nn.Module): class ConfigurableSwitchedResidualGenerator(nn.Module):
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
heightened_final_step=50000, upsample_factor=1): heightened_final_step=50000, upsample_factor=1):
super(ConfigurableSwitchedResidualGenerator, self).__init__() super(ConfigurableSwitchedResidualGenerator, self).__init__()
switches = [] switches = []
for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid): for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp)) switches.append(SwitchComputer(3, filters, growth, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
initialize_weights(switches, 1) initialize_weights(switches, 1)
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
initialize_weights([s.transforms for s in switches], .2 / len(switches)) initialize_weights([s.transforms for s in switches], .2 / len(switches))

View File

@ -63,7 +63,8 @@ def define_G(opt, net_key='network_G'):
final_temperature_step=opt_net['temperature_final_step']) final_temperature_step=opt_net['temperature_final_step'])
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f) netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
elif which_model == "ConfigurableSwitchedResidualGenerator": elif which_model == "ConfigurableSwitchedResidualGenerator":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'], netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
trans_filters_mid=opt_net['trans_filters_mid'], trans_filters_mid=opt_net['trans_filters_mid'],