diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index 4295e908..7748be6c 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -430,6 +430,44 @@ class ConvGnSilu(nn.Module): return x +''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard + kernel sizes. ''' +class ConvBnRelu(nn.Module): + def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1): + super(ConvBnRelu, self).__init__() + padding_map = {1: 0, 3: 1, 5: 2, 7: 3} + assert kernel_size in padding_map.keys() + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + if norm: + self.bn = nn.BatchNorm2d(filters_out) + else: + self.bn = None + if activation: + self.relu = nn.ReLU() + else: + self.relu = None + + # Init params. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') + m.weight.data *= weight_init_factor + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv(x) + if self.bn: + x = self.bn(x) + if self.relu: + return self.relu(x) + else: + return x + + # Simple way to chain multiple conv->act->norms together in an intuitive way. class MultiConvBlock(nn.Module): def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1): diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py index 3e4924f5..18c1f1b2 100644 --- a/codes/models/archs/spinenet_arch.py +++ b/codes/models/archs/spinenet_arch.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch.nn.init import kaiming_normal from torchvision.models.resnet import BasicBlock, Bottleneck -from models.archs.arch_util import ConvGnSilu +from models.archs.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu def constant_init(module, val, bias=0): @@ -332,4 +332,29 @@ class SpineNet(nn.Module): if spec.is_output: output_feat[spec.level] = target_feat - return tuple([self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level]) \ No newline at end of file + return tuple([self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level]) + + +# Attachs a simple 1x1 conv prediction head to a Spinenet. +class SpinenetWithLogits(SpineNet): + def __init__(self, + arch, + output_to_attach, + num_labels, + in_channels=3, + output_level=[3, 4], + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + zero_init_residual=True, + activation='relu', + use_input_norm=False, + double_reduce_early=True): + super().__init__(arch, in_channels, output_level, conv_cfg, norm_cfg, zero_init_residual, activation, use_input_norm, double_reduce_early) + self.output_to_attach = output_to_attach + self.tail = nn.Sequential(ConvBnRelu(256, 128, kernel_size=1, activation=True, norm=True, bias=False), + ConvBnRelu(128, 64, kernel_size=1, activation=True, norm=True, bias=False), + ConvBnRelu(64, num_labels, kernel_size=1, activation=False, norm=False, bias=True)) + + def forward(self, x): + fea = super().forward(x)[self.output_to_attach] + return self.tail(fea) diff --git a/codes/models/networks.py b/codes/models/networks.py index 0f504e23..bc947f2a 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -162,6 +162,10 @@ def define_G(opt, opt_net, scale=None): elif which_model == 'spinenet': from models.archs.spinenet_arch import SpineNet netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm']) + elif which_model == 'spinenet_with_logits': + from models.archs.spinenet_arch import SpinenetWithLogits + netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'], + in_channels=3, use_input_norm=opt_net['use_input_norm']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 6c42a205..84e12e3f 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -98,8 +98,18 @@ class CrossEntropy(ConfigurableLoss): self.ce = nn.CrossEntropyLoss() def forward(self, _, state): - labels = state[self.opt['labels']] logits = state[self.opt['logits']] + labels = state[self.opt['labels']] + if self.opt['rescale']: + labels = F.interpolate(labels.type(torch.float), size=logits.shape[2:], mode="nearest").type(torch.long) + if 'mask' in self.opt.keys(): + mask = state[self.opt['mask']] + if self.opt['rescale']: + mask = F.interpolate(mask, size=logits.shape[2:], mode="nearest") + logits = logits * mask + if self.opt['swap_channels']: + logits = logits.permute(0,2,3,1).contiguous() + assert labels.max()+1 <= logits.shape[-1] return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))