Remove batchnorms from resgen
This commit is contained in:
parent
68bcab03ae
commit
030648f2bc
|
@ -54,8 +54,8 @@ class ResidualBranch(nn.Module):
|
|||
class HalvingProcessingBlock(nn.Module):
|
||||
def __init__(self, filters):
|
||||
super(HalvingProcessingBlock, self).__init__()
|
||||
self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2)
|
||||
self.bnconv2 = ConvBnLelu(filters * 2, filters * 2)
|
||||
self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False)
|
||||
self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bnconv1(x)
|
||||
|
@ -68,7 +68,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
|
|||
convs = []
|
||||
current_filters = filters_init
|
||||
for i in range(num_convs):
|
||||
convs.append(ConvBnLelu(current_filters, current_filters + filter_growth))
|
||||
convs.append(ConvBnLelu(current_filters, current_filters + filter_growth, bn=False))
|
||||
current_filters += filter_growth
|
||||
return nn.Sequential(*convs), current_filters
|
||||
|
||||
|
@ -81,7 +81,7 @@ class SwitchComputer(nn.Module):
|
|||
final_filters = filters * 2 ** reduction_blocks
|
||||
self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
|
||||
proc_block_filters = max(final_filters // 2, transform_count)
|
||||
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters)
|
||||
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
|
||||
self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
|
||||
|
||||
self.transforms = nn.ModuleList([transform_block() for i in range(transform_count)])
|
||||
|
|
Loading…
Reference in New Issue
Block a user