forked from mrq/DL-Art-School
Spinenet: implementation without 4x downsampling right off the bat
This commit is contained in:
parent
384e3d54cc
commit
9429544a60
|
@ -253,7 +253,8 @@ class SpineNet(nn.Module):
|
|||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
zero_init_residual=True,
|
||||
activation='relu',
|
||||
use_input_norm=False):
|
||||
use_input_norm=False,
|
||||
double_reduce_early=True):
|
||||
super(SpineNet, self).__init__()
|
||||
self._block_specs = build_block_specs()[2:]
|
||||
self._endpoints_num_filters = SCALING_MAP[arch]['endpoints_num_filters']
|
||||
|
@ -262,6 +263,7 @@ class SpineNet(nn.Module):
|
|||
self._filter_size_scale = SCALING_MAP[arch]['filter_size_scale']
|
||||
self._init_block_fn = Bottleneck
|
||||
self._num_init_blocks = 2
|
||||
self._early_double_reduce = double_reduce_early
|
||||
self.zero_init_residual = zero_init_residual
|
||||
assert min(output_level) > 2 and max(output_level) < 8, "Output level out of range"
|
||||
self.output_level = output_level
|
||||
|
@ -274,12 +276,20 @@ class SpineNet(nn.Module):
|
|||
def _make_stem_layer(self, in_channels):
|
||||
"""Build the stem network."""
|
||||
# Build the first conv and maxpooling layers.
|
||||
if self._early_double_reduce:
|
||||
self.conv1 = ConvBnRelu(
|
||||
in_channels,
|
||||
64,
|
||||
kernel_size=7,
|
||||
stride=2)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
else:
|
||||
self.conv1 = ConvBnRelu(
|
||||
in_channels,
|
||||
64,
|
||||
kernel_size=7,
|
||||
stride=1)
|
||||
self.maxpool = None
|
||||
|
||||
# Build the initial level 2 blocks.
|
||||
self.init_block1 = make_res_layer(
|
||||
|
@ -340,7 +350,9 @@ class SpineNet(nn.Module):
|
|||
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device)
|
||||
input = (input - mean) / std
|
||||
|
||||
feat = self.maxpool(self.conv1(input))
|
||||
feat = self.conv1(input)
|
||||
if self.maxpool:
|
||||
feat = self.maxpool(feat)
|
||||
feat1 = self.init_block1(feat)
|
||||
feat2 = self.init_block2(feat1)
|
||||
block_feats = [feat1, feat2]
|
||||
|
|
Loading…
Reference in New Issue
Block a user