From 3c31bea1ac2b91e2bd71b3ef68980ece99e0cc26 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 6 Jul 2020 22:22:29 -0600 Subject: [PATCH] SRG2 architectural changes --- .../archs/SwitchedResidualGenerator_arch.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index a147510d..459bb5dc 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -115,6 +115,7 @@ class ConvBasisMultiplexer(nn.Module): self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth) gap = self.output_filter_count - multiplexer_channels + # Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly. self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False) self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False) self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True) @@ -152,18 +153,19 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): add_scalable_noise_to_transforms=False): super(ConfigurableSwitchedResidualGenerator2, self).__init__() switches = [] - self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False) - self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False) - self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False) + self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True) + self.sw_conv = ConvBnLelu(transformation_filters, transformation_filters, lelu=False, bias=True) + self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, biasd=True) + self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True) + self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True) + self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True) for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers): multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count) switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False), - transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), + transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters + growth, transformation_filters, kernel_size=kernel, depth=layers), transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms, - add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1)) - # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. - initialize_weights([s.transforms for s in switches], .2 / len(switches)) + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.2)) self.switches = nn.ModuleList(switches) self.transformation_counts = trans_counts @@ -178,16 +180,18 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): x = self.initial_conv(x) self.attentions = [] + swx = x for i, sw in enumerate(self.switches): - x, att = sw.forward(x, True) + swx, att = sw.forward(swx, True) self.attentions.append(att) + x = swx + self.sw_conv(x) - if self.upsample_factor > 1: - x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest") - - x = self.proc_conv(x) - x = self.final_conv(x) - return x, + assert x == 2 or x == 4 + x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) + if self.upsample_factor > 2: + x = F.interpolate(x, scale_factor=2, mode="nearest") + x = self.upconv2(x) + return self.final_conv(self.hr_conv(x)), def set_temperature(self, temp): [sw.set_temperature(temp) for sw in self.switches]