From 4b82d0815d373722afd690392f2610da86b29ecc Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 29 Jun 2020 10:09:51 -0600
Subject: [PATCH] NSG improvements

- Just use resnet blocks for the multiplexer trunk of the generator
- Every block initializes itself, rather than everything at the end
- Cleans up some messy parts of the architecture, including unnecessary
  kernel sizes and places where BN is not used properly.
---
 codes/models/archs/NestedSwitchGenerator.py        | 74 ++++++++++++-------
 .../archs/SwitchedResidualGenerator_arch.py        |  8 ++
 2 files changed, 56 insertions(+), 26 deletions(-)

diff --git a/codes/models/archs/NestedSwitchGenerator.py b/codes/models/archs/NestedSwitchGenerator.py
index 43372912..19be7599 100644
--- a/codes/models/archs/NestedSwitchGenerator.py
+++ b/codes/models/archs/NestedSwitchGenerator.py
@@ -1,10 +1,11 @@
 import torch
 from torch import nn
-from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, create_sequential_growing_processing_block, MultiConvBlock, initialize_weights
+from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, MultiConvBlock, initialize_weights
 from switched_conv import BareConvSwitch, compute_attention_specificity
 from switched_conv_util import save_attention_to_image
 from functools import partial
 import torch.nn.functional as F
+from torchvision.models.resnet import BasicBlock, Bottleneck
 
 
 class Switch(nn.Module):
@@ -55,29 +56,22 @@ class Switch(nn.Module):
         [t.set_temperature(temp) for t in self.transforms]
 
 
-class ResidualBlock(nn.Module):
-    def __init__(self, filters):
-        super(ResidualBlock, self).__init__()
-        self.lelu1 = nn.LeakyReLU(negative_slope=.1)
-        self.bn1 = nn.BatchNorm2d(filters)
-        self.conv1 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)
-        self.lelu2 = nn.LeakyReLU(negative_slope=.1)
-        self.bn2 = nn.BatchNorm2d(filters)
-        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)
-
-    def forward(self, x):
-        x = self.conv1(self.lelu1(self.bn1(x)))
-        return self.conv2(self.lelu2(self.bn2(x)))
-
-
 # Convolutional image processing block that optionally reduces image size by a factor of 2 using stride and performs a
 # series of residual-block-like processing operations on it.
 class Processor(nn.Module):
     def __init__(self, base_filters, processing_depth, reduce=False):
         super(Processor, self).__init__()
-        self.output_filter_count = base_filters * 2 if reduce else base_filters
-        self.initial = ConvBnLelu(base_filters, self.output_filter_count, kernel_size=1, stride=2 if reduce else 1)
-        self.res_blocks = nn.ModuleList([ResidualBlock(self.output_filter_count) for _ in range(processing_depth)])
+        self.output_filter_count = base_filters * 2
+
+        # Downsample block used for bottleneck.
+        downsample = nn.Sequential(
+            nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2),
+            nn.BatchNorm2d(self.output_filter_count),
+        )
+        # Bottleneck block outputs the requested filter size x4, but we only want x2.
+        self.initial = Bottleneck(base_filters, base_filters // 2, stride=2 if reduce else 1, downsample=downsample)
+
+        self.res_blocks = nn.ModuleList([BasicBlock(self.output_filter_count, self.output_filter_count) for _ in range(processing_depth)])
 
     def forward(self, x):
         x = self.initial(x)
@@ -89,14 +83,22 @@ class Processor(nn.Module):
 
 # Convolutional image processing block that constricts an input image with a large number of filters to a small number
 # of filters over a fixed number of layers.
 class Constrictor(nn.Module):
-    def __init__(self, filters, output_filters, use_bn=False):
+    def __init__(self, filters, output_filters):
         super(Constrictor, self).__init__()
         assert(filters > output_filters)
         gap = filters - output_filters
         gap_div_4 = int(gap / 4)
-        self.cbl1 = ConvBnLelu(filters, filters - (gap_div_4 * 2), bn=use_bn)
-        self.cbl2 = ConvBnLelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), bn=use_bn)
-        self.cbl3 = ConvBnLelu(filters - (gap_div_4 * 3), output_filters)
+        self.cbl1 = ConvBnLelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True)
+        self.cbl2 = ConvBnLelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True)
+        self.cbl3 = nn.Conv2d(filters - (gap_div_4 * 3), output_filters, kernel_size=1)
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
 
     def forward(self, x):
         x = self.cbl1(x)
@@ -148,6 +150,24 @@ class NestedSwitchComputer(nn.Module):
         self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
         self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
 
+        # Init the parameters in the trunk.
+        for m in self.processing_trunk.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+        nn.init.kaiming_normal_(self.anneal.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        for m in self.processing_trunk.modules():
+            if isinstance(m, Bottleneck):
+                nn.init.constant_(m.bn3.weight, 0)
+            elif isinstance(m, BasicBlock):
+                nn.init.constant_(m.bn2.weight, 0)
+
     def forward(self, x):
         trunk = []
         trunk_input = self.multiplexer_init_conv(x)
@@ -167,16 +187,17 @@ class NestedSwitchedGenerator(nn.Module):
                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
         super(NestedSwitchedGenerator, self).__init__()
-        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
-        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False)
+        self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, bn=False)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, bn=False)
 
         switches = []
         for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
             switches.append(NestedSwitchComputer(transform_filters=transformation_filters, switch_base_filters=switch_filters, num_switch_processing_layers=sw_proc,
                                                  nesting_depth=sw_reduce, transforms_at_leaf=trans_count, trans_kernel_size=kernel, trans_num_layers=layers,
                                                  trans_scale_init=.2/len(switch_reductions), initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
-        initialize_weights(switches, 1)
         self.switches = nn.ModuleList(switches)
+        nn.init.kaiming_normal_(self.initial_conv.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
+        nn.init.kaiming_normal_(self.final_conv.conv.weight, mode='fan_in', nonlinearity='leaky_relu')
 
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
@@ -187,6 +208,7 @@ class NestedSwitchedGenerator(nn.Module):
         self.upsample_factor = upsample_factor
 
     def forward(self, x):
+        k = x
         # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
         # is called for, then repair.
         if self.upsample_factor > 1:
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index ee53c1ce..4be9cb98 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -44,6 +44,14 @@ class MultiConvBlock(nn.Module):
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))
 
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
     def forward(self, x, noise=None):
         if noise is not None:
             noise = noise * self.noise_scale
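For reference, the channel arithmetic behind the new Processor trunk: torchvision's Bottleneck expands its planes argument by its expansion factor of 4, so passing base_filters // 2 yields base_filters * 2 output channels, matching both output_filter_count and the 1x1 downsample projection on the shortcut. A minimal sketch under that assumption (the filter count of 64 is illustrative, not taken from any repo config):

    import torch
    from torch import nn
    from torchvision.models.resnet import Bottleneck

    base_filters = 64                  # illustrative value only
    out_filters = base_filters * 2     # what Processor.output_filter_count becomes

    # 1x1 projection used as the Bottleneck shortcut so channels and stride match the main branch.
    downsample = nn.Sequential(
        nn.Conv2d(base_filters, out_filters, kernel_size=1, stride=2),
        nn.BatchNorm2d(out_filters),
    )
    # Bottleneck emits planes * 4 channels, so planes = base_filters // 2 gives base_filters * 2.
    block = Bottleneck(base_filters, base_filters // 2, stride=2, downsample=downsample)

    x = torch.randn(1, base_filters, 32, 32)
    print(block(x).shape)  # torch.Size([1, 128, 16, 16]): channels doubled, spatial dims halved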
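The effect of zero-initializing the last BN in each residual branch (the Goyal et al. trick cited above, https://arxiv.org/abs/1706.02677) can be checked on a standalone torchvision BasicBlock; this is a sketch of the idea, not the actual trunk wiring used by NestedSwitchComputer:

    import torch
    from torch import nn
    from torchvision.models.resnet import BasicBlock

    block = BasicBlock(64, 64)
    # Zero the scale of the final BN so the residual branch contributes nothing at initialization.
    nn.init.constant_(block.bn2.weight, 0)

    block.eval()  # use running stats so the identity behaviour is exact
    x = torch.randn(1, 64, 16, 16)
    with torch.no_grad():
        out = block(x)
    # With bn2.weight zeroed the branch output is zero, so the block reduces to relu(identity).
    print(torch.allclose(out, torch.relu(x)))  # True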