From 8f92c0a0885ae05c2299811cd343e01e09b758a2 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 6 Jul 2020 09:18:30 -0600 Subject: [PATCH] Interpolate attention well before softmax --- codes/models/archs/SRG1_arch.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/codes/models/archs/SRG1_arch.py b/codes/models/archs/SRG1_arch.py index 5350cfbc..826ccd58 100644 --- a/codes/models/archs/SRG1_arch.py +++ b/codes/models/archs/SRG1_arch.py @@ -60,12 +60,14 @@ class SwitchComputer(nn.Module): self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)]) final_filters = filters * 2 ** reduction_blocks self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks) - proc_block_filters = max(final_filters // 2, transform_count) - self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False) + self.post_interpolate_decimate = ConvBnSilu(final_filters, filters, kernel_size=1, silu=False, bn=False) + self.interpolate_process = ConvBnSilu(filters, filters) + self.interpolate_process2 = ConvBnSilu(filters, filters) tc = transform_count if self.enable_negative_transforms: tc = transform_count * 2 - self.final_switch_conv = nn.Conv2d(proc_block_filters, tc, 1, 1, 0) + assert filters > transform_count * 2 + self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0) self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)]) self.add_noise = add_scalable_noise_to_transforms @@ -89,10 +91,13 @@ class SwitchComputer(nn.Module): multiplexer = block.forward(multiplexer) for block in self.processing_blocks: multiplexer = block.forward(multiplexer) - multiplexer = self.proc_switch_conv(multiplexer) - multiplexer = self.final_switch_conv.forward(multiplexer) - # Interpolate the multiplexer across the entire shape of the image. + + # Interpolate the multiplexer across the entire shape of the image perform some post-processing before feeding into switch. multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest') + multiplexer = self.post_interpolate_decimate(multiplexer) + multiplexer = self.interpolate_process(multiplexer) + multiplexer = self.interpolate_process2(multiplexer) + multiplexer = self.final_switch_conv.forward(multiplexer) outputs, attention = self.switch(xformed, multiplexer, True) outputs = outputs * self.scale + self.bias