From 030648f2bc3f9cf2dd848373871af2ef95e58317 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 22 Jun 2020 17:23:36 -0600
Subject: [PATCH] Remove batchnorms from resgen

---
 codes/models/archs/SwitchedResidualGenerator_arch.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 1b43173c..955938d4 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -54,8 +54,8 @@ class ResidualBranch(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2)
-        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2)
+        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False)
+        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=False)
 
     def forward(self, x):
         x = self.bnconv1(x)
@@ -68,7 +68,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
     convs = []
     current_filters = filters_init
     for i in range(num_convs):
-        convs.append(ConvBnLelu(current_filters, current_filters + filter_growth))
+        convs.append(ConvBnLelu(current_filters, current_filters + filter_growth, bn=False))
         current_filters += filter_growth
     return nn.Sequential(*convs), current_filters
@@ -81,7 +81,7 @@ class SwitchComputer(nn.Module):
         final_filters = filters * 2 ** reduction_blocks
         self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
         proc_block_filters = max(final_filters // 2, transform_count)
-        self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
+        self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
         self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
         self.transforms = nn.ModuleList([transform_block() for i in range(transform_count)])
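
For reference, the ConvBnLelu module this patch keeps calling is defined elsewhere in
SwitchedResidualGenerator_arch.py and is not part of these hunks. Below is a minimal
sketch of what a conv -> optional batch norm -> LeakyReLU block with a bn flag typically
looks like; the kernel size, padding, and LeakyReLU slope are illustrative assumptions,
not values taken from this patch.

import torch.nn as nn

class ConvBnLelu(nn.Module):
    # Sketch only: Conv2d -> optional BatchNorm2d -> LeakyReLU.
    # kernel_size, padding, and negative_slope are assumed, not from this patch.
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, bn=True):
        super(ConvBnLelu, self).__init__()
        padding = kernel_size // 2  # keeps spatial size for odd kernels (before striding)
        # A conv bias is redundant directly before batch norm, so enable it only when bn=False.
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding, bias=not bn)
        self.bn = nn.BatchNorm2d(filters_out) if bn else None
        self.lelu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        return self.lelu(x)

With bn=False, the block reduces to a plain (possibly strided) convolution followed by a
LeakyReLU, which is what every call site touched by this patch now requests.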