From 703dec44723dda2077839704767d1a69c0dd2ad8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 3 Jul 2020 12:07:31 -0600 Subject: [PATCH] Add SpineNet & integrate with SRG New version of SRG uses SpineNet for a switch backbone. --- .../archs/SwitchedResidualGenerator_arch.py | 181 ++++++---- codes/models/archs/spinenet_arch.py | 319 ++++++++++++++++++ codes/models/networks.py | 7 + codes/train.py | 4 +- codes/utils/numeric_stability.py | 16 +- 5 files changed, 444 insertions(+), 83 deletions(-) create mode 100644 codes/models/archs/spinenet_arch.py diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 3eb647ed..47436bee 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -4,79 +4,11 @@ from switched_conv import BareConvSwitch, compute_attention_specificity import torch.nn.functional as F import functools from collections import OrderedDict -from models.archs.arch_util import initialize_weights +from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu from models.archs.RRDBNet_arch import ResidualDenseBlock_5C +from models.archs.spinenet_arch import SpineNet from switched_conv_util import save_attention_to_image -''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard - kernel sizes. ''' -class ConvBnRelu(nn.Module): - def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True): - super(ConvBnRelu, self).__init__() - padding_map = {1: 0, 3: 1, 5: 2, 7: 3} - assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) - if bn: - self.bn = nn.BatchNorm2d(filters_out) - else: - self.bn = None - if relu: - self.relu = nn.ReLU() - else: - self.relu = None - - # Init params. - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def forward(self, x): - x = self.conv(x) - if self.bn: - x = self.bn(x) - if self.relu: - return self.relu(x) - else: - return x - - -''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard - kernel sizes. ''' -class ConvBnLelu(nn.Module): - def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True): - super(ConvBnLelu, self).__init__() - padding_map = {1: 0, 3: 1, 5: 2, 7: 3} - assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) - if bn: - self.bn = nn.BatchNorm2d(filters_out) - else: - self.bn = None - if lelu: - self.lelu = nn.LeakyReLU(negative_slope=.1) - else: - self.lelu = None - - # Init params. 
- for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out', nonlinearity='leaky_relu' if self.lelu else 'linear') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def forward(self, x): - x = self.conv(x) - if self.bn: - x = self.bn(x) - if self.lelu: - return self.lelu(x) - else: - return x - class MultiConvBlock(nn.Module): def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False): @@ -214,7 +146,7 @@ class ConfigurableSwitchComputer(nn.Module): m = self.multiplexer(identity) # Interpolate the multiplexer across the entire shape of the image. - m = F.interpolate(m, size=x.shape[2:], mode='nearest') + m = F.interpolate(m, size=xformed[0].shape[2:], mode='nearest') outputs, attention = self.switch(xformed, m, True) outputs = identity + outputs * self.switch_scale @@ -252,6 +184,22 @@ class ConvBasisMultiplexer(nn.Module): return x +class SpineNetMultiplexer(nn.Module): + def __init__(self, input_channels, transform_count): + super(SpineNetMultiplexer, self).__init__() + self.backbone = SpineNet('49', in_channels=input_channels) + self.rdc1 = ConvBnRelu(256, 128, kernel_size=3, bias=False) + self.rdc2 = ConvBnRelu(128, 64, kernel_size=3, bias=False) + self.rdc3 = ConvBnRelu(64, transform_count, bias=False, bn=False, relu=False) + + def forward(self, x): + spine = self.backbone(x) + feat = self.rdc1(spine[0]) + feat = self.rdc2(feat) + feat = self.rdc3(feat) + return feat + + class ConvBasisMultiplexerReducer(nn.Module): def __init__(self, input_channels, base_filters, growth, reductions, processing_depth): super(ConvBasisMultiplexerReducer, self).__init__() @@ -415,6 +363,97 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): if step % 50 == 0: [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))] + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val + + +class Interpolate(nn.Module): + def __init__(self, factor): + super(Interpolate, self).__init__() + self.factor = factor + + def forward(self, x): + return F.interpolate(x, scale_factor=self.factor) + + +class ConfigurableSwitchedResidualGenerator3(nn.Module): + def __init__(self, trans_counts, + trans_kernel_sizes, + trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, + heightened_temp_min=1, + heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False, + add_scalable_noise_to_transforms=False): + super(ConfigurableSwitchedResidualGenerator3, self).__init__() + switches = [] + for trans_count, kernel, layers in zip(trans_counts, trans_kernel_sizes, trans_layers): + multiplx_fn = functools.partial(SpineNetMultiplexer, 3) + switches.append(ConfigurableSwitchComputer(base_filters=3, multiplexer_net=multiplx_fn, + pre_transform_block=functools.partial(nn.Sequential, + ConvBnLelu(3, transformation_filters, kernel_size=1, stride=4, bn=False, lelu=False, bias=False), + ResidualDenseBlock_5C( + transformation_filters), + 
ResidualDenseBlock_5C( + transformation_filters)), + transform_block=functools.partial(nn.Sequential, + ResidualDenseBlock_5C(transformation_filters), + Interpolate(4), + ConvBnLelu(transformation_filters, transformation_filters // 2, kernel_size=3, bias=False, bn=False), + ConvBnLelu(transformation_filters // 2, 3, kernel_size=1, bias=False, bn=False, lelu=False)), + transform_count=trans_count, init_temp=initial_temp, + enable_negative_transforms=enable_negative_transforms, + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, + init_scalar=.01)) + + self.switches = nn.ModuleList(switches) + self.transformation_counts = trans_counts + self.init_temperature = initial_temp + self.final_temperature_step = final_temperature_step + self.heightened_temp_min = heightened_temp_min + self.heightened_final_step = heightened_final_step + self.attentions = None + self.upsample_factor = upsample_factor + + def forward(self, x): + if self.upsample_factor > 1: + x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest") + + self.attentions = [] + for i, sw in enumerate(self.switches): + x, att = sw.forward(x, True) + self.attentions.append(att) + + return x, + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, int( + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step)) + if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1: + # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. + # without this, the attention specificity "spikes" incredibly fast in the last few iterations. + h_steps_total = self.heightened_final_step - self.final_temperature_step + h_steps_current = min(step - self.final_temperature_step, h_steps_total) + # The "gap" will represent the steps that need to be traveled as a linear function. 
+ h_gap = 1 / self.heightened_temp_min + temp = h_gap * h_steps_current / h_steps_total + # Invert temperature to represent reality on this side of the curve + temp = 1 / temp + self.set_temperature(temp) + if step % 50 == 0: + [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, + "a%i" % (i + 1,)) for i in range(len(self.switches))] + def get_debug_values(self, step): temp = self.switches[0].switch.temperature mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py new file mode 100644 index 00000000..7fcd9b17 --- /dev/null +++ b/codes/models/archs/spinenet_arch.py @@ -0,0 +1,319 @@ +# Taken and modified from https://github.com/lucifer443/SpineNet-Pytorch/blob/master/mmdet/models/backbones/spinenet.py + +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import kaiming_normal + +from torchvision.models.resnet import BasicBlock, Bottleneck +from torch.nn.modules.batchnorm import _BatchNorm +from models.archs.arch_util import ConvBnRelu + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +def kaiming_init(module, + a=0, + mode='fan_out', + nonlinearity='relu', + bias=0, + distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +FILTER_SIZE_MAP = { + 1: 32, + 2: 64, + 3: 128, + 4: 256, + 5: 256, + 6: 256, + 7: 256, +} + +def make_res_layer(block, + inplanes, + planes, + blocks, + stride=1, + dilation=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + dilation=dilation, + downsample=downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + dilation=dilation)) + + return nn.Sequential(*layers) + +# The fixed SpineNet architecture discovered by NAS. +# Each element represents a specification of a building block: +# (block_level, block_fn, (input_offset0, input_offset1), is_output). 
+SPINENET_BLOCK_SPECS = [ + (2, Bottleneck, (None, None), False), # init block + (2, Bottleneck, (None, None), False), # init block + (2, Bottleneck, (0, 1), False), + (4, BasicBlock, (0, 1), False), + (3, Bottleneck, (2, 3), False), + (4, Bottleneck, (2, 4), False), + (6, BasicBlock, (3, 5), False), + (4, Bottleneck, (3, 5), False), + (5, BasicBlock, (6, 7), False), + (7, BasicBlock, (6, 8), False), + (5, Bottleneck, (8, 9), False), + (5, Bottleneck, (8, 10), False), + (4, Bottleneck, (5, 10), True), + (3, Bottleneck, (4, 10), True), + (5, Bottleneck, (7, 12), True), + (7, Bottleneck, (5, 14), True), + (6, Bottleneck, (12, 14), True), +] + +SCALING_MAP = { + '49S': { + 'endpoints_num_filters': 128, + 'filter_size_scale': 0.65, + 'resample_alpha': 0.5, + 'block_repeats': 1, + }, + '49': { + 'endpoints_num_filters': 256, + 'filter_size_scale': 1.0, + 'resample_alpha': 0.5, + 'block_repeats': 1, + }, + '96': { + 'endpoints_num_filters': 256, + 'filter_size_scale': 1.0, + 'resample_alpha': 0.5, + 'block_repeats': 2, + }, + '143': { + 'endpoints_num_filters': 256, + 'filter_size_scale': 1.0, + 'resample_alpha': 1.0, + 'block_repeats': 3, + }, + '190': { + 'endpoints_num_filters': 512, + 'filter_size_scale': 1.3, + 'resample_alpha': 1.0, + 'block_repeats': 4, + }, +} + + +class BlockSpec(object): + """A container class that specifies the block configuration for SpineNet.""" + + def __init__(self, level, block_fn, input_offsets, is_output): + self.level = level + self.block_fn = block_fn + self.input_offsets = input_offsets + self.is_output = is_output + + +def build_block_specs(block_specs=None): + """Builds the list of BlockSpec objects for SpineNet.""" + if not block_specs: + block_specs = SPINENET_BLOCK_SPECS + return [BlockSpec(*b) for b in block_specs] + + +class Resample(nn.Module): + def __init__(self, in_channels, out_channels, scale, block_type, alpha=1.0): + super(Resample, self).__init__() + self.scale = scale + new_in_channels = int(in_channels * alpha) + if block_type == Bottleneck: + in_channels *= 4 + self.squeeze_conv = ConvBnRelu(in_channels, new_in_channels, kernel_size=1) + if scale < 1: + self.downsample_conv = ConvBnRelu(new_in_channels, new_in_channels, kernel_size=3, stride=2) + self.expand_conv = ConvBnRelu(new_in_channels, out_channels, kernel_size=1, relu=False) + + def _resize(self, x): + if self.scale == 1: + return x + elif self.scale > 1: + return F.interpolate(x, scale_factor=self.scale, mode='nearest') + else: + x = self.downsample_conv(x) + if self.scale < 0.5: + new_kernel_size = 3 if self.scale >= 0.25 else 5 + x = F.max_pool2d(x, kernel_size=new_kernel_size, stride=int(0.5/self.scale), padding=new_kernel_size//2) + return x + + def forward(self, inputs): + feat = self.squeeze_conv(inputs) + feat = self._resize(feat) + feat = self.expand_conv(feat) + return feat + + +class Merge(nn.Module): + """Merge two input tensors""" + def __init__(self, block_spec, alpha, filter_size_scale): + super(Merge, self).__init__() + out_channels = int(FILTER_SIZE_MAP[block_spec.level] * filter_size_scale) + if block_spec.block_fn == Bottleneck: + out_channels *= 4 + self.block = block_spec.block_fn + self.resample_ops = nn.ModuleList() + for spec_idx in block_spec.input_offsets: + spec = BlockSpec(*SPINENET_BLOCK_SPECS[spec_idx]) + in_channels = int(FILTER_SIZE_MAP[spec.level] * filter_size_scale) + scale = 2**(spec.level - block_spec.level) + self.resample_ops.append( + Resample(in_channels, out_channels, scale, spec.block_fn, alpha) + ) + + def forward(self, inputs): + assert 
len(inputs) == len(self.resample_ops) + parent0_feat = self.resample_ops[0](inputs[0]) + parent1_feat = self.resample_ops[1](inputs[1]) + target_feat = parent0_feat + parent1_feat + return target_feat + + +class SpineNet(nn.Module): + """Class to build SpineNet backbone""" + def __init__(self, + arch, + in_channels=3, + output_level=[3, 4, 5, 6, 7], + zero_init_residual=True): + super(SpineNet, self).__init__() + self._block_specs = build_block_specs()[2:] + self._endpoints_num_filters = SCALING_MAP[arch]['endpoints_num_filters'] + self._resample_alpha = SCALING_MAP[arch]['resample_alpha'] + self._block_repeats = SCALING_MAP[arch]['block_repeats'] + self._filter_size_scale = SCALING_MAP[arch]['filter_size_scale'] + self._init_block_fn = Bottleneck + self._num_init_blocks = 2 + self.zero_init_residual = zero_init_residual + assert min(output_level) > 2 and max(output_level) < 8, "Output level out of range" + self.output_level = output_level + + self._make_stem_layer(in_channels) + self._make_scale_permuted_network() + self._make_endpoints() + + def _make_stem_layer(self, in_channels): + """Build the stem network.""" + # Build the first conv and maxpooling layers. + self.conv1 = ConvBnRelu( + in_channels, + 64, + kernel_size=7, + stride=2) # Original paper had stride=2 and a maxpool after. + + # Build the initial level 2 blocks. + self.init_block1 = make_res_layer( + self._init_block_fn, + 64, + int(FILTER_SIZE_MAP[2] * self._filter_size_scale), + self._block_repeats) + self.init_block2 = make_res_layer( + self._init_block_fn, + int(FILTER_SIZE_MAP[2] * self._filter_size_scale) * 4, + int(FILTER_SIZE_MAP[2] * self._filter_size_scale), + self._block_repeats) + + def _make_endpoints(self): + self.endpoint_convs = nn.ModuleDict() + for block_spec in self._block_specs: + if block_spec.is_output: + in_channels = int(FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4 + self.endpoint_convs[str(block_spec.level)] = ConvBnRelu(in_channels, + self._endpoints_num_filters, + kernel_size=1, + relu=False) + + def _make_scale_permuted_network(self): + self.merge_ops = nn.ModuleList() + self.scale_permuted_blocks = nn.ModuleList() + for spec in self._block_specs: + self.merge_ops.append( + Merge(spec, self._resample_alpha, self._filter_size_scale) + ) + channels = int(FILTER_SIZE_MAP[spec.level] * self._filter_size_scale) + in_channels = channels * 4 if spec.block_fn == Bottleneck else channels + self.scale_permuted_blocks.append( + make_res_layer(spec.block_fn, + in_channels, + channels, + self._block_repeats) + ) + + def init_weights(self, pretrained=None): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + + def forward(self, input): + feat = self.conv1(input) + feat1 = self.init_block1(feat) + feat2 = self.init_block2(feat1) + block_feats = [feat1, feat2] + output_feat = {} + num_outgoing_connections = [0, 0] + + for i, spec in enumerate(self._block_specs): + target_feat = self.merge_ops[i]([block_feats[feat_idx] for feat_idx in spec.input_offsets]) + # Connect intermediate blocks with outdegree 0 to the output block. 
+ if spec.is_output: + for j, (j_feat, j_connections) in enumerate( + zip(block_feats, num_outgoing_connections)): + if j_connections == 0 and j_feat.shape == target_feat.shape: + target_feat += j_feat + num_outgoing_connections[j] += 1 + target_feat = F.relu(target_feat, inplace=True) + target_feat = self.scale_permuted_blocks[i](target_feat) + block_feats.append(target_feat) + num_outgoing_connections.append(0) + for feat_idx in spec.input_offsets: + num_outgoing_connections[feat_idx] += 1 + if spec.is_output: + output_feat[spec.level] = target_feat + + return [self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level] \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index e6f32efc..13264a87 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -66,6 +66,13 @@ def define_G(opt, net_key='network_G'): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) + elif which_model == "ConfigurableSwitchedResidualGenerator3": + netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator3(trans_counts=opt_net['trans_counts'], + trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], + transformation_filters=opt_net['transformation_filters'], + initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], + heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], + upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) elif which_model == "NestedSwitchGenerator": netG = ng.NestedSwitchedGenerator(switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'], diff --git a/codes/train.py b/codes/train.py index b2b87afc..8c7c0dc8 100644 --- a/codes/train.py +++ b/codes/train.py @@ -33,7 +33,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_feat_resgen2_lr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_srg3.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -162,7 +162,7 @@ def main(): current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: - current_step = -1 + current_step = 0 start_epoch = 0 #### training diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py index 3e7855f9..30f4c4c9 100644 --- a/codes/utils/numeric_stability.py +++ b/codes/utils/numeric_stability.py @@ -4,7 +4,7 @@ import models.archs.SwitchedResidualGenerator_arch as srg import models.archs.NestedSwitchGenerator as nsg import functools -blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax] +blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax, srg.Interpolate] def install_forward_trace_hooks(module, id="base"): if type(module) in blacklisted_modules: return @@ -96,15 +96,11 @@ if __name__ == "__main__": torch.randn(1, 3, 64, 64), device='cuda') ''' - 
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2, - switch_filters=[16,16,16,16,16], - switch_growths=[32,32,32,32,32], - switch_reductions=[1,1,1,1,1], - switch_processing_layers=[5,5,5,5,5], - trans_counts=[8,8,8,8,8], - trans_kernel_sizes=[3,3,3,3,3], - trans_layers=[3,3,3,3,3], + test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3, + trans_counts=[8], + trans_kernel_sizes=[3], + trans_layers=[3], transformation_filters=64, initial_temp=10), - torch.randn(1, 3, 64, 64), + torch.randn(1, 3, 128, 128), device='cuda')
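For reviewers who want to exercise the new generator without a full training run, below is a minimal forward-pass sketch mirroring the updated codes/utils/numeric_stability.py invocation. It assumes the repository's codes/ directory is on PYTHONPATH and that the external switched_conv / switched_conv_util dependencies are installed; it runs on CPU for simplicity, whereas the stability script runs on CUDA.

# Sketch only: smoke-test the SpineNet-backed SRG3 with the same arguments the
# updated numeric_stability.py uses (one switch, 8 transforms, 128x128 input).
import torch
import models.archs.SwitchedResidualGenerator_arch as srg

gen = srg.ConfigurableSwitchedResidualGenerator3(trans_counts=[8],
                                                 trans_kernel_sizes=[3],
                                                 trans_layers=[3],
                                                 transformation_filters=64,
                                                 initial_temp=10)

with torch.no_grad():
    # forward() returns a 1-tuple. With the default upsample_factor=1 the output
    # keeps the input resolution: the stride-4 pre-transform inside each switch
    # is undone by Interpolate(4) in the transform block.
    out, = gen(torch.randn(1, 3, 128, 128))

print(out.shape)  # expected: torch.Size([1, 3, 128, 128])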
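The patch changes the default -opt argument in codes/train.py to ../options/train_div2k_srg3.yml, but that options file is not part of this diff. As a sketch only, these are the network_G keys the new define_G branch in codes/models/networks.py reads for this model; the values are placeholders rather than the settings used for training, and the 'which_model_G' key name is assumed from the existing options convention rather than shown in this patch.

# Placeholder sketch of the network_G options read by the new define_G branch.
opt_net = {
    'which_model_G': 'ConfigurableSwitchedResidualGenerator3',  # assumed key name
    'trans_counts': [8],
    'trans_kernel_sizes': [3],
    'trans_layers': [3],
    'transformation_filters': 64,
    'temperature': 10,                # passed as initial_temp
    'temperature_final_step': 50000,  # passed as final_temperature_step
    'heightened_temp_min': 1,
    'heightened_final_step': 50000,
    'add_noise': False,               # passed as add_scalable_noise_to_transforms
}
# define_G additionally passes upsample_factor=scale, taken from the top-level
# 'scale' option of the YAML file.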