Fix spinenet bug
This commit is contained in:
parent
b008a27d39
commit
e706911c83
|
@ -299,16 +299,15 @@ class SpineNet(nn.Module):
|
|||
constant_init(m.bn2, 0)
|
||||
|
||||
def forward(self, input):
|
||||
# Spinenet is pretrained on the standard pytorch input norm. The image will need to
|
||||
# be normalized before feeding it through.
|
||||
if self.conv1 is not None:
|
||||
if self.use_input_norm:
|
||||
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input.device)
|
||||
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device)
|
||||
input = (input - mean) / std
|
||||
|
||||
if self.conv1 is not None:
|
||||
feat = self.conv1(input)
|
||||
feat = self.maxpool(feat)
|
||||
else:
|
||||
feat = input
|
||||
feat1 = self.init_block1(feat)
|
||||
feat2 = self.init_block2(feat1)
|
||||
block_feats = [feat1, feat2]
|
||||
|
|
Loading…
Reference in New Issue
Block a user