Mods to SwitchedResidualGenerator_arch

- Increased processing for high-resolution switches
- Do stride=2 first in HalvingProcessingBlock
This commit is contained in:
James Betker 2020-06-16 14:19:12 -06:00
parent 70c764b9d4
commit 2def96203e
2 changed files with 20 additions and 15 deletions

View File

@ -43,12 +43,13 @@ class ResidualBranch(nn.Module):
return x * self.scale + self.bias
# VGG-style layer with Conv->BN->Activation->Conv(stride2)->BN->Activation
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
# Doubles the input filter count.
class HalvingProcessingBlock(nn.Module):
def __init__(self, filters):
super(HalvingProcessingBlock, self).__init__()
self.bnconv1 = ConvBnLelu(filters, filters)
self.bnconv2 = ConvBnLelu(filters, filters * 2, stride=2)
self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2)
self.bnconv2 = ConvBnLelu(filters * 2, filters * 2)
def forward(self, x):
x = self.bnconv1(x)
@ -56,11 +57,12 @@ class HalvingProcessingBlock(nn.Module):
class SwitchComputer(nn.Module):
def __init__(self, channels_in, filters, transform_block, transform_count, reductions, init_temp=20):
def __init__(self, channels_in, filters, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
super(SwitchComputer, self).__init__()
self.filter_conv = ConvBnLelu(channels_in, filters)
self.blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reductions)])
final_filters = filters * 2 ** reductions
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
final_filters = filters * 2 ** reduction_blocks
self.processing_blocks = nn.ModuleList([ConvBnLelu(final_filters, final_filters) for i in range(processing_blocks)])
proc_block_filters = max(final_filters // 2, transform_count)
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
@ -77,7 +79,9 @@ class SwitchComputer(nn.Module):
xformed.append(torch.zeros_like(xformed[0]))
multiplexer = self.filter_conv(x)
for block in self.blocks:
for block in self.reduction_blocks:
multiplexer = block.forward(multiplexer)
for block in self.processing_blocks:
multiplexer = block.forward(multiplexer)
multiplexer = self.proc_switch_conv(multiplexer)
multiplexer = self.final_switch_conv.forward(multiplexer)
@ -92,10 +96,10 @@ class SwitchComputer(nn.Module):
class SwitchedResidualGenerator(nn.Module):
def __init__(self, switch_filters, initial_temp=20, final_temperature_step=50000):
super(SwitchedResidualGenerator, self).__init__()
self.switch1 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=7, depth=3), 4, 4, initial_temp)
self.switch2 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=5, depth=3), 8, 3, initial_temp)
self.switch3 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=3), 16, 2, initial_temp)
self.switch4 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=2), 32, 1, initial_temp)
self.switch1 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=7, depth=3), 4, 4, 0, initial_temp)
self.switch2 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=5, depth=3), 8, 3, 0, initial_temp)
self.switch3 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=3), 16, 2, 1, initial_temp)
self.switch4 = SwitchComputer(3, switch_filters * 2, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=2), 32, 1, 2, initial_temp)
initialize_weights([self.switch1, self.switch2, self.switch3, self.switch4], 1)
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
initialize_weights([self.switch1.transforms, self.switch2.transforms, self.switch3.transforms, self.switch4.transforms], .05)
@ -155,11 +159,11 @@ class SwitchedResidualGenerator(nn.Module):
class ConfigurableSwitchedResidualGenerator(nn.Module):
def __init__(self, switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers, initial_temp=20, final_temperature_step=50000):
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, initial_temp=20, final_temperature_step=50000):
super(ConfigurableSwitchedResidualGenerator, self).__init__()
switches = []
for filters, depth, trans_count, kernel, layers in zip(switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers):
switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, depth, initial_temp))
for filters, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
initialize_weights(switches, 1)
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
initialize_weights([s.transforms for s in switches], .05)

View File

@ -73,7 +73,8 @@ def define_G(opt, net_key='network_G'):
netG = SwitchedGen_arch.SwitchedResidualGenerator(switch_filters=opt_net['nf'], initial_temp=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
elif which_model == "ConfigurableSwitchedResidualGenerator":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_depths=opt_net['switch_depths'], trans_counts=opt_net['trans_counts'],
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])