Add growth channel to switch_growths for flat networks
This commit is contained in:
parent
3b81712c49
commit
68bcab03ae
|
@ -3,6 +3,7 @@ from torch import nn
|
||||||
from switched_conv import BareConvSwitch, compute_attention_specificity
|
from switched_conv import BareConvSwitch, compute_attention_specificity
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import functools
|
import functools
|
||||||
|
from collections import OrderedDict
|
||||||
from models.archs.arch_util import initialize_weights
|
from models.archs.arch_util import initialize_weights
|
||||||
from switched_conv_util import save_attention_to_image
|
from switched_conv_util import save_attention_to_image
|
||||||
|
|
||||||
|
@ -61,13 +62,24 @@ class HalvingProcessingBlock(nn.Module):
|
||||||
return self.bnconv2(x)
|
return self.bnconv2(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Creates a nested series of convolutional blocks. Each block processes the input data in-place and adds
|
||||||
|
# filter_growth filters. Return is (nn.Sequential, ending_filters)
|
||||||
|
def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
|
||||||
|
convs = []
|
||||||
|
current_filters = filters_init
|
||||||
|
for i in range(num_convs):
|
||||||
|
convs.append(ConvBnLelu(current_filters, current_filters + filter_growth))
|
||||||
|
current_filters += filter_growth
|
||||||
|
return nn.Sequential(*convs), current_filters
|
||||||
|
|
||||||
|
|
||||||
class SwitchComputer(nn.Module):
|
class SwitchComputer(nn.Module):
|
||||||
def __init__(self, channels_in, filters, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
|
def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
|
||||||
super(SwitchComputer, self).__init__()
|
super(SwitchComputer, self).__init__()
|
||||||
self.filter_conv = ConvBnLelu(channels_in, filters)
|
self.filter_conv = ConvBnLelu(channels_in, filters)
|
||||||
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
|
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
|
||||||
final_filters = filters * 2 ** reduction_blocks
|
final_filters = filters * 2 ** reduction_blocks
|
||||||
self.processing_blocks = nn.ModuleList([ConvBnLelu(final_filters, final_filters) for i in range(processing_blocks)])
|
self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
|
||||||
proc_block_filters = max(final_filters // 2, transform_count)
|
proc_block_filters = max(final_filters // 2, transform_count)
|
||||||
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
|
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
|
||||||
self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
|
self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
|
||||||
|
@ -104,13 +116,13 @@ class SwitchComputer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ConfigurableSwitchedResidualGenerator(nn.Module):
|
class ConfigurableSwitchedResidualGenerator(nn.Module):
|
||||||
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||||
trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||||
heightened_final_step=50000, upsample_factor=1):
|
heightened_final_step=50000, upsample_factor=1):
|
||||||
super(ConfigurableSwitchedResidualGenerator, self).__init__()
|
super(ConfigurableSwitchedResidualGenerator, self).__init__()
|
||||||
switches = []
|
switches = []
|
||||||
for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
|
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
|
||||||
switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
|
switches.append(SwitchComputer(3, filters, growth, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
|
||||||
initialize_weights(switches, 1)
|
initialize_weights(switches, 1)
|
||||||
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
|
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
|
||||||
initialize_weights([s.transforms for s in switches], .2 / len(switches))
|
initialize_weights([s.transforms for s in switches], .2 / len(switches))
|
||||||
|
|
|
@ -63,7 +63,8 @@ def define_G(opt, net_key='network_G'):
|
||||||
final_temperature_step=opt_net['temperature_final_step'])
|
final_temperature_step=opt_net['temperature_final_step'])
|
||||||
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
||||||
elif which_model == "ConfigurableSwitchedResidualGenerator":
|
elif which_model == "ConfigurableSwitchedResidualGenerator":
|
||||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'],
|
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
|
||||||
|
switch_reductions=opt_net['switch_reductions'],
|
||||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||||
trans_filters_mid=opt_net['trans_filters_mid'],
|
trans_filters_mid=opt_net['trans_filters_mid'],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user