diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py index 18c1f1b2..c16ef2a1 100644 --- a/codes/models/archs/spinenet_arch.py +++ b/codes/models/archs/spinenet_arch.py @@ -353,7 +353,8 @@ class SpinenetWithLogits(SpineNet): self.output_to_attach = output_to_attach self.tail = nn.Sequential(ConvBnRelu(256, 128, kernel_size=1, activation=True, norm=True, bias=False), ConvBnRelu(128, 64, kernel_size=1, activation=True, norm=True, bias=False), - ConvBnRelu(64, num_labels, kernel_size=1, activation=False, norm=False, bias=True)) + ConvBnRelu(64, num_labels, kernel_size=1, activation=False, norm=False, bias=True), + nn.Softmax(dim=1)) def forward(self, x): fea = super().forward(x)[self.output_to_attach]