Spinenet with logits head

James Betker 2020-12-15 17:16:19 -07:00
parent 8e0e883050
commit fc376d34b2
4 changed files with 80 additions and 3 deletions

models/archs/arch_util.py

@@ -430,6 +430,44 @@ class ConvGnSilu(nn.Module):
        return x


''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
    kernel sizes. '''
class ConvBnRelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super(ConvBnRelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.bn = nn.BatchNorm2d(filters_out)
        else:
            self.bn = None
        if activation:
            self.relu = nn.ReLU()
        else:
            self.relu = None

        # Init params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        if self.relu:
            return self.relu(x)
        else:
            return x


# Simple way to chain multiple conv->act->norms together in an intuitive way.
class MultiConvBlock(nn.Module):
    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1):
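
For context, a minimal usage sketch of the new ConvBnRelu block (the tensor shapes here are illustrative, not part of the commit):

import torch

block = ConvBnRelu(filters_in=3, filters_out=64, kernel_size=3, stride=1)  # Conv -> BatchNorm -> ReLU
x = torch.randn(4, 3, 32, 32)  # NCHW input batch
y = block(x)                   # 3x3 kernel is auto-padded, so spatial dims are preserved at stride 1
print(y.shape)                 # torch.Size([4, 64, 32, 32])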

models/archs/spinenet_arch.py

@@ -6,7 +6,7 @@ import torch.nn.functional as F
from torch.nn.init import kaiming_normal
from torchvision.models.resnet import BasicBlock, Bottleneck
from models.archs.arch_util import ConvGnSilu
from models.archs.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
def constant_init(module, val, bias=0):
@@ -332,4 +332,29 @@ class SpineNet(nn.Module):
            if spec.is_output:
                output_feat[spec.level] = target_feat

        return tuple([self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level])

# Attaches a simple 1x1 conv prediction head to a SpineNet.
class SpinenetWithLogits(SpineNet):
    def __init__(self,
                 arch,
                 output_to_attach,
                 num_labels,
                 in_channels=3,
                 output_level=[3, 4],
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 zero_init_residual=True,
                 activation='relu',
                 use_input_norm=False,
                 double_reduce_early=True):
        super().__init__(arch, in_channels, output_level, conv_cfg, norm_cfg, zero_init_residual, activation, use_input_norm, double_reduce_early)
        self.output_to_attach = output_to_attach
        self.tail = nn.Sequential(ConvBnRelu(256, 128, kernel_size=1, activation=True, norm=True, bias=False),
                                  ConvBnRelu(128, 64, kernel_size=1, activation=True, norm=True, bias=False),
                                  ConvBnRelu(64, num_labels, kernel_size=1, activation=False, norm=False, bias=True))

    def forward(self, x):
        fea = super().forward(x)[self.output_to_attach]
        return self.tail(fea)
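
To make the head concrete, a standalone sketch of what the tail computes on one endpoint feature map, reusing the ConvBnRelu block added above (the batch/spatial shape and the 8-label count are illustrative assumptions):

import torch
import torch.nn as nn

num_labels = 8  # illustrative
tail = nn.Sequential(ConvBnRelu(256, 128, kernel_size=1, activation=True, norm=True, bias=False),
                     ConvBnRelu(128, 64, kernel_size=1, activation=True, norm=True, bias=False),
                     ConvBnRelu(64, num_labels, kernel_size=1, activation=False, norm=False, bias=True))
feat = torch.randn(2, 256, 16, 16)  # stand-in for a SpineNet endpoint feature map
logits = tail(feat)                 # -> (2, 8, 16, 16): per-pixel logits over num_labels classes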

View File

@@ -162,6 +162,10 @@ def define_G(opt, opt_net, scale=None):
    elif which_model == 'spinenet':
        from models.archs.spinenet_arch import SpineNet
        netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
    elif which_model == 'spinenet_with_logits':
        from models.archs.spinenet_arch import SpinenetWithLogits
        netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
                                  in_channels=3, use_input_norm=opt_net['use_input_norm'])
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
    return netG
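
For reference, a hypothetical opt_net fragment covering the keys this branch reads; the values are illustrative assumptions, and whatever key selects which_model lives outside this hunk:

opt_net = {
    'arch': '49',           # passed through str() as the SpineNet arch spec
    'output_to_attach': 0,  # index into the SpineNet output tuple (0 or 1 for the default output_level=[3, 4])
    'num_labels': 8,        # width of the head's final 1x1 conv
    'use_input_norm': True,
}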

View File

@@ -98,8 +98,18 @@ class CrossEntropy(ConfigurableLoss):
        self.ce = nn.CrossEntropyLoss()

    def forward(self, _, state):
        logits = state[self.opt['logits']]
        labels = state[self.opt['labels']]
        if self.opt['rescale']:
            labels = F.interpolate(labels.type(torch.float), size=logits.shape[2:], mode="nearest").type(torch.long)
        if 'mask' in self.opt.keys():
            mask = state[self.opt['mask']]
            if self.opt['rescale']:
                mask = F.interpolate(mask, size=logits.shape[2:], mode="nearest")
            logits = logits * mask
        if self.opt['swap_channels']:
            logits = logits.permute(0, 2, 3, 1).contiguous()
        assert labels.max() + 1 <= logits.shape[-1]
        return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
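
A standalone sketch of the loss path above (shapes and label count are illustrative; it assumes labels arrive as a 4D NCHW tensor so F.interpolate can resize them):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(2, 8, 16, 16)            # NCHW per-pixel logits from the head
labels = torch.randint(0, 8, (2, 1, 64, 64))  # full-resolution integer labels
labels = F.interpolate(labels.type(torch.float), size=logits.shape[2:], mode="nearest").type(torch.long)
logits = logits.permute(0, 2, 3, 1).contiguous()  # the swap_channels path: NCHW -> NHWC
loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))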