From 78276afcaad8149265fb262e016e411152518c4a Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 1 Jul 2020 11:43:25 -0600 Subject: [PATCH] Experiment: Back to lelu --- codes/models/archs/SwitchedResidualGenerator_arch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 27e77775..dc5e6a18 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -116,7 +116,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_ convs = [] current_filters = filters_init for i in range(num_convs): - convs.append(ConvBnRelu(current_filters, current_filters + filter_growth, bn=True, bias=False)) + convs.append(ConvBnLelu(current_filters, current_filters + filter_growth, bn=True, bias=False)) current_filters += filter_growth return nn.Sequential(*convs), current_filters @@ -222,15 +222,15 @@ class ConfigurableSwitchComputer(nn.Module): class ConvBasisMultiplexer(nn.Module): def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True): super(ConvBasisMultiplexer, self).__init__() - self.filter_conv = ConvBnRelu(input_channels, base_filters, bias=True) + self.filter_conv = ConvBnLelu(input_channels, base_filters, bias=True) self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)])) reduction_filters = base_filters * 2 ** reductions self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth) gap = self.output_filter_count - multiplexer_channels - self.cbl1 = ConvBnRelu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False) - self.cbl2 = ConvBnRelu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False) - self.cbl3 = ConvBnRelu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True) + self.cbl1 = ConvBnLelu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False) + self.cbl2 = ConvBnLelu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False) + self.cbl3 = ConvBnLelu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True) def forward(self, x): x = self.filter_conv(x)