From 604763be6863b5ff25df5f7737923454138dd2db Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 1 Jul 2020 09:54:29 -0600 Subject: [PATCH] NSG r7 Converts the switching trunk to a VGG-style network to make it more comparable to SRG architectures. --- codes/models/archs/NestedSwitchGenerator.py | 107 +++--------------- .../archs/SwitchedResidualGenerator_arch.py | 47 +++++++- 2 files changed, 57 insertions(+), 97 deletions(-) diff --git a/codes/models/archs/NestedSwitchGenerator.py b/codes/models/archs/NestedSwitchGenerator.py index 72b7f2bb..f663f4cc 100644 --- a/codes/models/archs/NestedSwitchGenerator.py +++ b/codes/models/archs/NestedSwitchGenerator.py @@ -1,64 +1,11 @@ import torch from torch import nn -from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, MultiConvBlock, initialize_weights +from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, ConvBnRelu, MultiConvBlock, initialize_weights from switched_conv import BareConvSwitch, compute_attention_specificity from switched_conv_util import save_attention_to_image from functools import partial import torch.nn.functional as F -import numpy as np - - -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -# Taken from Fixup resnet implementation https://github.com/hongyi-zhang/Fixup/blob/master/imagenet/models/fixup_resnet_imagenet.py -class FixupBottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(FixupBottleneck, self).__init__() - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.bias1a = nn.Parameter(torch.zeros(1)) - self.conv1 = conv1x1(inplanes, planes) - self.bias1b = 
nn.Parameter(torch.zeros(1)) - self.bias2a = nn.Parameter(torch.zeros(1)) - self.conv2 = conv3x3(planes, planes, stride) - self.bias2b = nn.Parameter(torch.zeros(1)) - self.bias3a = nn.Parameter(torch.zeros(1)) - self.conv3 = conv1x1(planes, planes * self.expansion) - self.scale = nn.Parameter(torch.ones(1)) - self.bias3b = nn.Parameter(torch.zeros(1)) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x + self.bias1a) - out = self.relu(out + self.bias1b) - - out = self.conv2(out + self.bias2a) - out = self.relu(out + self.bias2b) - - out = self.conv3(out + self.bias3a) - out = out * self.scale + self.bias3b - - if self.downsample is not None: - identity = self.downsample(x + self.bias1a) - - out += identity - out = self.relu(out) - - return out +from collections import OrderedDict class Switch(nn.Module): @@ -110,30 +57,21 @@ class Switch(nn.Module): # Convolutional image processing block that optionally reduces image size by a factor of 2 using stride and performs a -# series of residual-block-like processing operations on it. +# series of conv blocks on it class Processor(nn.Module): def __init__(self, base_filters, processing_depth, reduce=False): super(Processor, self).__init__() self.output_filter_count = base_filters * (2 if reduce else 1) - - # Downsample block used for bottleneck. - if reduce: - downsample = nn.Sequential( - nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2, bias=False), - nn.BatchNorm2d(self.output_filter_count), - ) - else: - downsample = None - # Bottleneck block outputs the requested filter sizex4, but we only want x2. 
- self.initial = FixupBottleneck(base_filters, self.output_filter_count // 4, stride=2 if reduce else 1, downsample=downsample) - self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)]) + self.pre = ConvBnRelu(base_filters, base_filters, kernel_size=3, bias=True) + self.initial = ConvBnRelu(base_filters, self.output_filter_count, kernel_size=1, stride=2 if reduce else 1, bias=False) + self.blocks = nn.Sequential(OrderedDict( + [(str(i), ConvBnRelu(self.output_filter_count, self.output_filter_count, kernel_size=3, bias=False)) for i in range(processing_depth)])) def forward(self, x): - x = (self.initial(x) - .4) / .6 - for b in self.res_blocks: - r = (b(x) - .4) / .6 - x = r + x - return x + x = self.pre(x) + x = self.initial(x) + x = self.blocks(x) + return (x - .39) / .58 # Convolutional image processing block that constricts an input image with a large number of filters to a small number @@ -144,15 +82,15 @@ class Constrictor(nn.Module): assert(filters > output_filters) gap = filters - output_filters gap_div_4 = int(gap / 4) - self.cbl1 = ConvBnLelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True) - self.cbl2 = ConvBnLelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True) - self.cbl3 = ConvBnLelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, lelu=False, bn=False) + self.cbl1 = ConvBnRelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True, bias=True) + self.cbl2 = ConvBnRelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True, bias=False) + self.cbl3 = ConvBnRelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, relu=False, bn=False, bias=False) def forward(self, x): x = self.cbl1(x) x = self.cbl2(x) - x = self.cbl3(x) / 4 - return x + x = self.cbl3(x) + return x / 2.67 class RecursiveSwitchedTransform(nn.Module): @@ -200,19 +138,6 @@ class NestedSwitchComputer(nn.Module): self.switch = 
RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms) self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False) - # Init the parameters in the trunk. Uses the fixup algorithm for residual conv initialization. - self.num_layers = nesting_depth + nesting_depth * num_switch_processing_layers - for m in self.processing_trunk.modules(): - if isinstance(m, FixupBottleneck): - nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25)) - nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25)) - nn.init.constant_(m.conv3.weight, 0) - if m.downsample is not None: - nn.init.normal_(m.downsample[0].weight, mean=0, std=np.sqrt(2 / (m.downsample[0].weight.shape[0] * np.prod(m.downsample[0].weight.shape[2:])))) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - def forward(self, x): feed_forward = x trunk = [] diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 38d3846a..e587e296 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -7,14 +7,49 @@ from collections import OrderedDict from models.archs.arch_util import initialize_weights from switched_conv_util import save_attention_to_image +''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard + kernel sizes. 
''' +class ConvBnRelu(nn.Module): + def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True): + super(ConvBnRelu, self).__init__() + padding_map = {1: 0, 3: 1, 5: 2, 7: 3} + assert kernel_size in padding_map.keys() + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + if bn: + self.bn = nn.BatchNorm2d(filters_out) + else: + self.bn = None + if relu: + self.relu = nn.ReLU() + else: + self.relu = None -''' Convenience class with Conv->BN->LeakyRelu. Includes Kaiming weight initialization. ''' + # Init params. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv(x) + if self.bn: + x = self.bn(x) + if self.relu: + return self.relu(x) + else: + return x + + +''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard + kernel sizes. 
''' class ConvBnLelu(nn.Module): - def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True): + def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True): super(ConvBnLelu, self).__init__() padding_map = {1: 0, 3: 1, 5: 2, 7: 3} assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size]) + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) if bn: self.bn = nn.BatchNorm2d(filters_out) else: @@ -47,9 +82,9 @@ class MultiConvBlock(nn.Module): assert depth >= 2 super(MultiConvBlock, self).__init__() self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01)) - self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn)] + - [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn) for i in range(depth-2)] + - [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)]) + self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False)] + + [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False) for i in range(depth-2)] + + [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False)]) self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init)) self.bias = nn.Parameter(torch.zeros(1))