From 9429544a600768b8f55aa7ff16713d4e01d96e9e Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 21 Sep 2020 12:36:30 -0600 Subject: [PATCH] Spinenet: implementation without 4x downsampling right off the bat --- codes/models/archs/spinenet_arch.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py index e813781c..1e5a192b 100644 --- a/codes/models/archs/spinenet_arch.py +++ b/codes/models/archs/spinenet_arch.py @@ -253,7 +253,8 @@ class SpineNet(nn.Module): norm_cfg=dict(type='BN', requires_grad=True), zero_init_residual=True, activation='relu', - use_input_norm=False): + use_input_norm=False, + double_reduce_early=True): super(SpineNet, self).__init__() self._block_specs = build_block_specs()[2:] self._endpoints_num_filters = SCALING_MAP[arch]['endpoints_num_filters'] @@ -262,6 +263,7 @@ class SpineNet(nn.Module): self._filter_size_scale = SCALING_MAP[arch]['filter_size_scale'] self._init_block_fn = Bottleneck self._num_init_blocks = 2 + self._early_double_reduce = double_reduce_early self.zero_init_residual = zero_init_residual assert min(output_level) > 2 and max(output_level) < 8, "Output level out of range" self.output_level = output_level @@ -274,12 +276,20 @@ class SpineNet(nn.Module): def _make_stem_layer(self, in_channels): """Build the stem network.""" # Build the first conv and maxpooling layers. - self.conv1 = ConvBnRelu( - in_channels, - 64, - kernel_size=7, - stride=2) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + if self._early_double_reduce: + self.conv1 = ConvBnRelu( + in_channels, + 64, + kernel_size=7, + stride=2) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + self.conv1 = ConvBnRelu( + in_channels, + 64, + kernel_size=7, + stride=1) + self.maxpool = None # Build the initial level 2 blocks. self.init_block1 = make_res_layer( @@ -340,7 +350,9 @@ class SpineNet(nn.Module): std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device) input = (input - mean) / std - feat = self.maxpool(self.conv1(input)) + feat = self.conv1(input) + if self.maxpool: + feat = self.maxpool(feat) feat1 = self.init_block1(feat) feat2 = self.init_block2(feat1) block_feats = [feat1, feat2]