Mods to SwitchedResidualGenerator_arch
- Increased processing for high-resolution switches - Do stride=2 first in HalvingProcessingBlock
This commit is contained in:
parent
70c764b9d4
commit
2def96203e
|
@ -43,12 +43,13 @@ class ResidualBranch(nn.Module):
|
|||
return x * self.scale + self.bias
|
||||
|
||||
|
||||
# VGG-style layer with Conv->BN->Activation->Conv(stride2)->BN->Activation
|
||||
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
|
||||
# Doubles the input filter count.
|
||||
class HalvingProcessingBlock(nn.Module):
|
||||
def __init__(self, filters):
|
||||
super(HalvingProcessingBlock, self).__init__()
|
||||
self.bnconv1 = ConvBnLelu(filters, filters)
|
||||
self.bnconv2 = ConvBnLelu(filters, filters * 2, stride=2)
|
||||
self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2)
|
||||
self.bnconv2 = ConvBnLelu(filters * 2, filters * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bnconv1(x)
|
||||
|
@ -56,11 +57,12 @@ class HalvingProcessingBlock(nn.Module):
|
|||
|
||||
|
||||
class SwitchComputer(nn.Module):
|
||||
def __init__(self, channels_in, filters, transform_block, transform_count, reductions, init_temp=20):
|
||||
def __init__(self, channels_in, filters, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
|
||||
super(SwitchComputer, self).__init__()
|
||||
self.filter_conv = ConvBnLelu(channels_in, filters)
|
||||
self.blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reductions)])
|
||||
final_filters = filters * 2 ** reductions
|
||||
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
|
||||
final_filters = filters * 2 ** reduction_blocks
|
||||
self.processing_blocks = nn.ModuleList([ConvBnLelu(final_filters, final_filters) for i in range(processing_blocks)])
|
||||
proc_block_filters = max(final_filters // 2, transform_count)
|
||||
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
|
||||
self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
|
||||
|
@ -77,7 +79,9 @@ class SwitchComputer(nn.Module):
|
|||
xformed.append(torch.zeros_like(xformed[0]))
|
||||
|
||||
multiplexer = self.filter_conv(x)
|
||||
for block in self.blocks:
|
||||
for block in self.reduction_blocks:
|
||||
multiplexer = block.forward(multiplexer)
|
||||
for block in self.processing_blocks:
|
||||
multiplexer = block.forward(multiplexer)
|
||||
multiplexer = self.proc_switch_conv(multiplexer)
|
||||
multiplexer = self.final_switch_conv.forward(multiplexer)
|
||||
|
@ -92,10 +96,10 @@ class SwitchComputer(nn.Module):
|
|||
class SwitchedResidualGenerator(nn.Module):
|
||||
def __init__(self, switch_filters, initial_temp=20, final_temperature_step=50000):
|
||||
super(SwitchedResidualGenerator, self).__init__()
|
||||
self.switch1 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=7, depth=3), 4, 4, initial_temp)
|
||||
self.switch2 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=5, depth=3), 8, 3, initial_temp)
|
||||
self.switch3 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=3), 16, 2, initial_temp)
|
||||
self.switch4 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=2), 32, 1, initial_temp)
|
||||
self.switch1 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=7, depth=3), 4, 4, 0, initial_temp)
|
||||
self.switch2 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=5, depth=3), 8, 3, 0, initial_temp)
|
||||
self.switch3 = SwitchComputer(3, switch_filters, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=3), 16, 2, 1, initial_temp)
|
||||
self.switch4 = SwitchComputer(3, switch_filters * 2, functools.partial(ResidualBranch, 3, 3, kernel_size=3, depth=2), 32, 1, 2, initial_temp)
|
||||
initialize_weights([self.switch1, self.switch2, self.switch3, self.switch4], 1)
|
||||
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
|
||||
initialize_weights([self.switch1.transforms, self.switch2.transforms, self.switch3.transforms, self.switch4.transforms], .05)
|
||||
|
@ -155,11 +159,11 @@ class SwitchedResidualGenerator(nn.Module):
|
|||
|
||||
|
||||
class ConfigurableSwitchedResidualGenerator(nn.Module):
|
||||
def __init__(self, switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers, initial_temp=20, final_temperature_step=50000):
|
||||
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, initial_temp=20, final_temperature_step=50000):
|
||||
super(ConfigurableSwitchedResidualGenerator, self).__init__()
|
||||
switches = []
|
||||
for filters, depth, trans_count, kernel, layers in zip(switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers):
|
||||
switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, depth, initial_temp))
|
||||
for filters, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
|
||||
switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
|
||||
initialize_weights(switches, 1)
|
||||
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
|
||||
initialize_weights([s.transforms for s in switches], .05)
|
||||
|
|
|
@ -73,7 +73,8 @@ def define_G(opt, net_key='network_G'):
|
|||
netG = SwitchedGen_arch.SwitchedResidualGenerator(switch_filters=opt_net['nf'], initial_temp=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'])
|
||||
elif which_model == "ConfigurableSwitchedResidualGenerator":
|
||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_depths=opt_net['switch_depths'], trans_counts=opt_net['trans_counts'],
|
||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user