Spinenet with logits head
This commit is contained in:
parent
8e0e883050
commit
fc376d34b2
|
@ -430,6 +430,44 @@ class ConvGnSilu(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
class ConvBnRelu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
|
||||
super(ConvBnRelu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||
if norm:
|
||||
self.bn = nn.BatchNorm2d(filters_out)
|
||||
else:
|
||||
self.bn = None
|
||||
if activation:
|
||||
self.relu = nn.ReLU()
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
|
||||
m.weight.data *= weight_init_factor
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
if self.bn:
|
||||
x = self.bn(x)
|
||||
if self.relu:
|
||||
return self.relu(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
# Simple way to chain multiple conv->act->norms together in an intuitive way.
|
||||
class MultiConvBlock(nn.Module):
|
||||
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1):
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|||
from torch.nn.init import kaiming_normal
|
||||
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
from models.archs.arch_util import ConvGnSilu
|
||||
from models.archs.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
|
||||
|
||||
|
||||
def constant_init(module, val, bias=0):
|
||||
|
@ -332,4 +332,29 @@ class SpineNet(nn.Module):
|
|||
if spec.is_output:
|
||||
output_feat[spec.level] = target_feat
|
||||
|
||||
return tuple([self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level])
|
||||
return tuple([self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level])
|
||||
|
||||
|
||||
# Attachs a simple 1x1 conv prediction head to a Spinenet.
|
||||
class SpinenetWithLogits(SpineNet):
|
||||
def __init__(self,
|
||||
arch,
|
||||
output_to_attach,
|
||||
num_labels,
|
||||
in_channels=3,
|
||||
output_level=[3, 4],
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
zero_init_residual=True,
|
||||
activation='relu',
|
||||
use_input_norm=False,
|
||||
double_reduce_early=True):
|
||||
super().__init__(arch, in_channels, output_level, conv_cfg, norm_cfg, zero_init_residual, activation, use_input_norm, double_reduce_early)
|
||||
self.output_to_attach = output_to_attach
|
||||
self.tail = nn.Sequential(ConvBnRelu(256, 128, kernel_size=1, activation=True, norm=True, bias=False),
|
||||
ConvBnRelu(128, 64, kernel_size=1, activation=True, norm=True, bias=False),
|
||||
ConvBnRelu(64, num_labels, kernel_size=1, activation=False, norm=False, bias=True))
|
||||
|
||||
def forward(self, x):
|
||||
fea = super().forward(x)[self.output_to_attach]
|
||||
return self.tail(fea)
|
||||
|
|
|
@ -162,6 +162,10 @@ def define_G(opt, opt_net, scale=None):
|
|||
elif which_model == 'spinenet':
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
elif which_model == 'spinenet_with_logits':
|
||||
from models.archs.spinenet_arch import SpinenetWithLogits
|
||||
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
|
|
|
@ -98,8 +98,18 @@ class CrossEntropy(ConfigurableLoss):
|
|||
self.ce = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, _, state):
|
||||
labels = state[self.opt['labels']]
|
||||
logits = state[self.opt['logits']]
|
||||
labels = state[self.opt['labels']]
|
||||
if self.opt['rescale']:
|
||||
labels = F.interpolate(labels.type(torch.float), size=logits.shape[2:], mode="nearest").type(torch.long)
|
||||
if 'mask' in self.opt.keys():
|
||||
mask = state[self.opt['mask']]
|
||||
if self.opt['rescale']:
|
||||
mask = F.interpolate(mask, size=logits.shape[2:], mode="nearest")
|
||||
logits = logits * mask
|
||||
if self.opt['swap_channels']:
|
||||
logits = logits.permute(0,2,3,1).contiguous()
|
||||
assert labels.max()+1 <= logits.shape[-1]
|
||||
return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user