Remove batchnorms from resgen

James Betker 2020-06-22 17:23:36 -06:00
parent 68bcab03ae
commit 030648f2bc

@@ -54,8 +54,8 @@ class ResidualBranch(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2)
-        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2)
+        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False)
+        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=False)
 
     def forward(self, x):
         x = self.bnconv1(x)
@@ -68,7 +68,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
     convs = []
     current_filters = filters_init
     for i in range(num_convs):
-        convs.append(ConvBnLelu(current_filters, current_filters + filter_growth))
+        convs.append(ConvBnLelu(current_filters, current_filters + filter_growth, bn=False))
        current_filters += filter_growth
    return nn.Sequential(*convs), current_filters
@@ -81,7 +81,7 @@ class SwitchComputer(nn.Module):
         final_filters = filters * 2 ** reduction_blocks
         self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
         proc_block_filters = max(final_filters // 2, transform_count)
-        self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
+        self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
         self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
         self.transforms = nn.ModuleList([transform_block() for i in range(transform_count)])
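For context on what the `bn=False` flag added throughout this commit does: a minimal sketch of a ConvBnLelu-style block with an optional BatchNorm stage is shown below. The real module is defined elsewhere in this repo, so the kernel size, padding, and activation slope here are assumptions for illustration only.

import torch.nn as nn

class ConvBnLelu(nn.Module):
    """Conv2d, optionally followed by BatchNorm, then LeakyReLU.

    Illustrative sketch only; the actual implementation in this
    repo may differ in kernel size, padding, and slope.
    """
    def __init__(self, filters_in, filters_out, stride=1, bn=True):
        super(ConvBnLelu, self).__init__()
        # 3x3 conv; the bias is usually redundant when BatchNorm follows.
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size=3,
                              stride=stride, padding=1, bias=not bn)
        # bn=False (as set throughout this commit) skips normalization.
        self.bn = nn.BatchNorm2d(filters_out) if bn else None
        self.lelu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        return self.lelu(x)

With a block structured like this, passing bn=False at every call site, as the diff above does, removes all BatchNorm layers from the generator while leaving the convolution and activation path unchanged.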