From e706911c83d0d381b57e4a7c703c4f3214016daf Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 17 Oct 2020 20:20:36 -0600 Subject: [PATCH] Fix spinenet bug --- codes/models/archs/spinenet_arch.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py index 5dcd84a0..85fb71dd 100644 --- a/codes/models/archs/spinenet_arch.py +++ b/codes/models/archs/spinenet_arch.py @@ -299,16 +299,15 @@ class SpineNet(nn.Module): constant_init(m.bn2, 0) def forward(self, input): - # Spinenet is pretrained on the standard pytorch input norm. The image will need to - # be normalized before feeding it through. - if self.use_input_norm: - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input.device) - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device) - input = (input - mean) / std - if self.conv1 is not None: + if self.use_input_norm: + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device) + input = (input - mean) / std feat = self.conv1(input) feat = self.maxpool(feat) + else: + feat = input feat1 = self.init_block1(feat) feat2 = self.init_block2(feat1) block_feats = [feat1, feat2]