From 0acad81035e4baebcd7b9da00e3a5bf7b5477f22 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 6 Jul 2020 22:40:40 -0600
Subject: [PATCH] More SRG2 adjustments..

---
 .../archs/SwitchedResidualGenerator_arch.py   | 86 +++++++++----------
 1 file changed, 42 insertions(+), 44 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 7ac786c8..c9bb9ea2 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -54,6 +54,46 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
     return nn.Sequential(*convs), current_filters
 
 
+class ConvBasisMultiplexer(nn.Module):
+    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
+        super(ConvBasisMultiplexer, self).__init__()
+        self.filter_conv = ConvBnSilu(input_channels, base_filters, bias=True)
+        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
+        reduction_filters = base_filters * 2 ** reductions
+        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
+
+        gap = self.output_filter_count - multiplexer_channels
+        # Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly.
+        self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
+        self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
+        self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
+
+    def forward(self, x):
+        x = self.filter_conv(x)
+        x = self.reduction_blocks(x)
+        x = self.processing_blocks(x)
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+
+
+class SpineNetMultiplexer(nn.Module):
+    def __init__(self, input_channels, transform_count):
+        super(SpineNetMultiplexer, self).__init__()
+        self.backbone = SpineNet('49', in_channels=input_channels)
+        self.rdc1 = ConvBnSilu(256, 128, kernel_size=3, bias=False)
+        self.rdc2 = ConvBnSilu(128, 64, kernel_size=3, bias=False)
+        self.rdc3 = ConvBnSilu(64, transform_count, bias=False, bn=False, relu=False)
+
+    def forward(self, x):
+        spine = self.backbone(x)
+        feat = self.rdc1(spine[0])
+        feat = self.rdc2(feat)
+        feat = self.rdc3(feat)
+        return feat
+
+
 class ConfigurableSwitchComputer(nn.Module):
     def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
                  enable_negative_transforms=False, add_scalable_noise_to_transforms=False, init_scalar=1):
@@ -106,46 +146,6 @@ class ConfigurableSwitchComputer(nn.Module):
         self.switch.set_attention_temperature(temp)
 
 
-class ConvBasisMultiplexer(nn.Module):
-    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
-        super(ConvBasisMultiplexer, self).__init__()
-        self.filter_conv = ConvBnSilu(input_channels, base_filters, bias=True)
-        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
-        reduction_filters = base_filters * 2 ** reductions
-        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
-
-        gap = self.output_filter_count - multiplexer_channels
-        # Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly.
-        self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
-        self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
-        self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
-
-    def forward(self, x):
-        x = self.filter_conv(x)
-        x = self.reduction_blocks(x)
-        x = self.processing_blocks(x)
-        x = self.cbl1(x)
-        x = self.cbl2(x)
-        x = self.cbl3(x)
-        return x
-
-
-class SpineNetMultiplexer(nn.Module):
-    def __init__(self, input_channels, transform_count):
-        super(SpineNetMultiplexer, self).__init__()
-        self.backbone = SpineNet('49', in_channels=input_channels)
-        self.rdc1 = ConvBnSilu(256, 128, kernel_size=3, bias=False)
-        self.rdc2 = ConvBnSilu(128, 64, kernel_size=3, bias=False)
-        self.rdc3 = ConvBnSilu(64, transform_count, bias=False, bn=False, relu=False)
-
-    def forward(self, x):
-        spine = self.backbone(x)
-        feat = self.rdc1(spine[0])
-        feat = self.rdc2(feat)
-        feat = self.rdc3(feat)
-        return feat
-
-
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
@@ -165,7 +165,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
                                                        pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
                                                        transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters + growth, transformation_filters, kernel_size=kernel, depth=layers),
                                                        transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
-                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.2))
+                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.1))
 
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
@@ -181,11 +181,9 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         x = self.initial_conv(x)
 
         self.attentions = []
-        swx = x
         for i, sw in enumerate(self.switches):
-            swx, att = sw.forward(swx, True)
+            x, att = sw.forward(x, True)
             self.attentions.append(att)
-        x = swx + self.sw_conv(x)
 
         x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
         if self.upsample_factor > 2:
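
Note (editor's annotation, not part of the patch): the ConvBasisMultiplexer moved above now narrows its channel count toward multiplexer_channels in fixed fractions of the gap: cbl1 closes half of it, cbl2 three quarters, and cbl3 the rest. A minimal sketch of that schedule, using only the arithmetic visible in the diff and a hypothetical helper name (nothing imported from this repo):

    # Hypothetical helper mirroring ConvBasisMultiplexer's channel narrowing.
    def conv_basis_channel_schedule(output_filter_count, multiplexer_channels):
        gap = output_filter_count - multiplexer_channels
        return [
            output_filter_count,                    # input width to cbl1
            output_filter_count - (gap // 2),       # cbl1 output: half the gap closed
            output_filter_count - (3 * gap // 4),   # cbl2 output: three quarters closed
            multiplexer_channels,                   # cbl3 output: target width
        ]

    print(conv_basis_channel_schedule(192, 8))  # -> [192, 100, 54, 8]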
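
The forward() hunk also changes how switches compose: the separate swx accumulator and the sw_conv shortcut are dropped, so each switch now transforms the running feature map directly. A runnable toy of the new chaining pattern, assuming only that a switch returns a (features, attention) pair as in this file; ToySwitch is a stand-in, not the repo's switch:

    import torch
    import torch.nn as nn

    class ToySwitch(nn.Module):
        """Stand-in switch: residual conv plus a dummy attention tensor."""
        def __init__(self, channels):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

        def forward(self, x, output_attention=False):
            att = torch.softmax(torch.randn(x.shape[0], 4), dim=1)  # stand-in attention
            return x + self.conv(x), att

    switches = nn.ModuleList([ToySwitch(8) for _ in range(3)])
    x = torch.randn(1, 8, 16, 16)
    attentions = []
    for sw in switches:
        x, att = sw(x, True)   # each switch feeds the next, as in the new forward()
        attentions.append(att)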