From 407224eba15d0d963a97c37d8bb70b687db28337 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 25 Jun 2020 18:17:05 -0600
Subject: [PATCH] Re-work SwitchedResgen2

Got rid of the converged multiplexer bases but kept the configurable
architecture. The new multiplexers look a lot like the old one.

Took some cues from the transformer architecture: translate the image
to a higher filter-space and stay there for the duration of the model's
computation. Also perform convs after each switch to allow the model to
anneal away issues that arise.
---
 .../archs/SwitchedResidualGenerator_arch.py   | 55 +++++++++++++++----
 1 file changed, 44 insertions(+), 11 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index cf8780a1..ea78214c 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -175,9 +175,32 @@ class ConfigurableSwitchComputer(nn.Module):
         self.switch.set_attention_temperature(temp)
 
 
-class ResidualBasisMultiplexerBase(nn.Module):
+class ConvBasisMultiplexer(nn.Module):
+    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
+        super(ConvBasisMultiplexer, self).__init__()
+        self.filter_conv = ConvBnLelu(input_channels, base_filters)
+        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
+        reduction_filters = base_filters * 2 ** reductions
+        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
+
+        gap = self.output_filter_count - multiplexer_channels
+        self.cbl1 = ConvBnLelu(self.output_filter_count, self.output_filter_count - (gap // 4), bn=use_bn)
+        self.cbl2 = ConvBnLelu(self.output_filter_count - (gap // 4), self.output_filter_count - (gap // 2), bn=use_bn)
+        self.cbl3 = ConvBnLelu(self.output_filter_count - (gap // 2), multiplexer_channels)
+
+    def forward(self, x):
+        x = self.filter_conv(x)
+        x = self.reduction_blocks(x)
+        x = self.processing_blocks(x)
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+
+
+class ConvBasisMultiplexerBase(nn.Module):
     def __init__(self, input_channels, base_filters, growth, reductions, processing_depth):
-        super(ResidualBasisMultiplexerBase, self).__init__()
+        super(ConvBasisMultiplexerBase, self).__init__()
         self.filter_conv = ConvBnLelu(input_channels, base_filters)
         self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
         reduction_filters = base_filters * 2 ** reductions
@@ -190,9 +213,9 @@ class ResidualBasisMultiplexerBase(nn.Module):
         return x
 
 
-class ResidualBasisMultiplexerLeaf(nn.Module):
+class ConvBasisMultiplexerLeaf(nn.Module):
     def __init__(self, base, filters, multiplexer_channels, use_bn=False):
-        super(ResidualBasisMultiplexerLeaf, self).__init__()
+        super(ConvBasisMultiplexerLeaf, self).__init__()
         assert(filters > multiplexer_channels)
         gap = filters - multiplexer_channels
         assert(gap % 4 == 0)
@@ -277,19 +300,24 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
 
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
         switches = []
-        multiplexer_base = ResidualBasisMultiplexerBase(3, switch_filters[0], switch_growths[0], switch_reductions[0], switch_processing_layers[0])
-        for trans_count, kernel, layers, mid_filters in zip(trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
-            leaf_fn = functools.partial(ResidualBasisMultiplexerLeaf, multiplexer_base, multiplexer_base.output_filter_count)
-            switches.append(ConfigurableSwitchComputer(leaf_fn, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+        post_switch_proc = []
+        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False)
+        for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
+            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
+            switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(ResidualBranch, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+            post_switch_proc.append(ConvBnLelu(transformation_filters, transformation_filters, bn=False))
         initialize_weights(switches, 1)
         # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
         initialize_weights([s.transforms for s in switches], .2 / len(switches))
         self.switches = nn.ModuleList(switches)
+        initialize_weights([p for p in post_switch_proc], .01)
+        self.post_switch_convs = nn.ModuleList(post_switch_proc)
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
@@ -304,11 +332,16 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         if self.upsample_factor > 1:
             x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
 
+        x = self.initial_conv(x)
+
         self.attentions = []
-        for i, sw in enumerate(self.switches):
+        for i, (sw, conv) in enumerate(zip(self.switches, self.post_switch_convs)):
             sw_out, att = sw.forward(x, True)
-            x = x + sw_out
             self.attentions.append(att)
+            x = x + sw_out
+            x = x + conv(x)
+
+        x = self.final_conv(x)
         return x,
 
     def set_temperature(self, temp):
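
Note: for readers skimming the diff, the dataflow this patch sets up in
ConfigurableSwitchedResidualGenerator2 can be summarized with a minimal,
self-contained sketch. _Conv and _Switch below are hypothetical stand-ins
for ConvBnLelu and ConfigurableSwitchComputer (whose real definitions live
elsewhere in SwitchedResidualGenerator_arch.py), and the channel width of
64 is arbitrary; only the composition mirrors the patch, not the internals.

    import torch
    import torch.nn as nn

    class _Conv(nn.Module):
        # Stand-in for ConvBnLelu: 3x3 conv + LeakyReLU (batchnorm omitted).
        def __init__(self, cin, cout):
            super().__init__()
            self.net = nn.Sequential(nn.Conv2d(cin, cout, 3, padding=1),
                                     nn.LeakyReLU())

        def forward(self, x):
            return self.net(x)

    class _Switch(nn.Module):
        # Stand-in for ConfigurableSwitchComputer: a channel-preserving
        # transform whose output is added back residually by the caller.
        def __init__(self, ch):
            super().__init__()
            self.trans = _Conv(ch, ch)

        def forward(self, x):
            return self.trans(x)

    class HighFilterSpaceGenerator(nn.Module):
        def __init__(self, ch=64, n_switches=3):
            super().__init__()
            self.initial_conv = _Conv(3, ch)   # lift RGB into filter-space once
            self.switches = nn.ModuleList([_Switch(ch) for _ in range(n_switches)])
            self.post_switch_convs = nn.ModuleList([_Conv(ch, ch) for _ in range(n_switches)])
            self.final_conv = _Conv(ch, 3)     # drop back to RGB only at the end

        def forward(self, x):
            x = self.initial_conv(x)
            for sw, conv in zip(self.switches, self.post_switch_convs):
                x = x + sw(x)    # residual switch output
                x = x + conv(x)  # post-switch conv to anneal the result
            return self.final_conv(x)

    # e.g. HighFilterSpaceGenerator()(torch.randn(1, 3, 32, 32)).shape == (1, 3, 32, 32)

The point of the structure: the image is translated into filter-space by a
single initial conv, every switch and post-switch conv operates at that fixed
width, and RGB is only reconstructed by the final conv, matching the
transformer-style "stay in a high filter-space" idea the commit message describes.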