From 0584c3b587397c6879cdee00b66ea7db5f90b2c4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 23 Jun 2020 09:41:12 -0600 Subject: [PATCH] Add negative_transforms switch to resgen --- .../models/archs/SwitchedResidualGenerator_arch.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 955938d4..ed5cbf5b 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -74,17 +74,23 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_ class SwitchComputer(nn.Module): - def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20): + def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0, + init_temp=20, enable_negative_transforms=False): super(SwitchComputer, self).__init__() + self.enable_negative_transforms = enable_negative_transforms + self.filter_conv = ConvBnLelu(channels_in, filters) self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)]) final_filters = filters * 2 ** reduction_blocks self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks) proc_block_filters = max(final_filters // 2, transform_count) self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False) - self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0) + tc = transform_count + if self.enable_negative_transforms: + tc = transform_count * 2 + self.final_switch_conv = nn.Conv2d(proc_block_filters, tc, 1, 1, 0) - self.transforms = nn.ModuleList([transform_block() for i in range(transform_count)]) + self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)]) # And the switch itself, including learned scalars self.switch = BareConvSwitch(initial_temperature=init_temp) @@ -93,6 +99,8 @@ class SwitchComputer(nn.Module): def forward(self, x, output_attention_weights=False): xformed = [t.forward(x) for t in self.transforms] + if self.enable_negative_transforms: + xformed.extend([-t for t in xformed]) multiplexer = self.filter_conv(x) for block in self.reduction_blocks: