Fix spinenet bug

This commit is contained in:
James Betker 2020-10-17 20:20:36 -06:00
parent b008a27d39
commit e706911c83

View File

@ -299,16 +299,15 @@ class SpineNet(nn.Module):
constant_init(m.bn2, 0)
def forward(self, input):
# Spinenet is pretrained on the standard pytorch input norm. The image will need to
# be normalized before feeding it through.
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input.device)
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device)
input = (input - mean) / std
if self.conv1 is not None:
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input.device)
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device)
input = (input - mean) / std
feat = self.conv1(input)
feat = self.maxpool(feat)
else:
feat = input
feat1 = self.init_block1(feat)
feat2 = self.init_block2(feat1)
block_feats = [feat1, feat2]