Experiment: Back to lelu
This commit is contained in:
parent
b945021c90
commit
78276afcaa
|
@ -116,7 +116,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
|
|||
convs = []
|
||||
current_filters = filters_init
|
||||
for i in range(num_convs):
|
||||
convs.append(ConvBnRelu(current_filters, current_filters + filter_growth, bn=True, bias=False))
|
||||
convs.append(ConvBnLelu(current_filters, current_filters + filter_growth, bn=True, bias=False))
|
||||
current_filters += filter_growth
|
||||
return nn.Sequential(*convs), current_filters
|
||||
|
||||
|
@ -222,15 +222,15 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
class ConvBasisMultiplexer(nn.Module):
|
||||
def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
|
||||
super(ConvBasisMultiplexer, self).__init__()
|
||||
self.filter_conv = ConvBnRelu(input_channels, base_filters, bias=True)
|
||||
self.filter_conv = ConvBnLelu(input_channels, base_filters, bias=True)
|
||||
self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
|
||||
reduction_filters = base_filters * 2 ** reductions
|
||||
self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
|
||||
|
||||
gap = self.output_filter_count - multiplexer_channels
|
||||
self.cbl1 = ConvBnRelu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
|
||||
self.cbl2 = ConvBnRelu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
|
||||
self.cbl3 = ConvBnRelu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
|
||||
self.cbl1 = ConvBnLelu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
|
||||
self.cbl2 = ConvBnLelu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
|
||||
self.cbl3 = ConvBnLelu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.filter_conv(x)
|
||||
|
|
Loading…
Reference in New Issue
Block a user