More NSG improvements (v3)
Move to a fully fixup residual network for the switch (no batch norms). Fix a bunch of other small bugs. Add in a temporary latent feed-forward from the bottom of the switch. Fix several initialization issues.
This commit is contained in:
parent
4b82d0815d
commit
773753073f
|
@ -5,7 +5,60 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
|
|||
from switched_conv_util import save_attention_to_image
|
||||
from functools import partial
|
||||
import torch.nn.functional as F
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
import numpy as np
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
# Taken from Fixup resnet implementation https://github.com/hongyi-zhang/Fixup/blob/master/imagenet/models/fixup_resnet_imagenet.py
|
||||
class FixupBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(FixupBottleneck, self).__init__()
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.bias1a = nn.Parameter(torch.zeros(1))
|
||||
self.conv1 = conv1x1(inplanes, planes)
|
||||
self.bias1b = nn.Parameter(torch.zeros(1))
|
||||
self.bias2a = nn.Parameter(torch.zeros(1))
|
||||
self.conv2 = conv3x3(planes, planes, stride)
|
||||
self.bias2b = nn.Parameter(torch.zeros(1))
|
||||
self.bias3a = nn.Parameter(torch.zeros(1))
|
||||
self.conv3 = conv1x1(planes, planes * self.expansion)
|
||||
self.scale = nn.Parameter(torch.ones(1))
|
||||
self.bias3b = nn.Parameter(torch.zeros(1))
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x + self.bias1a)
|
||||
out = self.relu(out + self.bias1b)
|
||||
|
||||
out = self.conv2(out + self.bias2a)
|
||||
out = self.relu(out + self.bias2b)
|
||||
|
||||
out = self.conv3(out + self.bias3a)
|
||||
out = out * self.scale + self.bias3b
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x + self.bias1a)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Switch(nn.Module):
|
||||
|
@ -21,6 +74,10 @@ class Switch(nn.Module):
|
|||
self.scale = nn.Parameter(torch.ones(1))
|
||||
self.bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
if not self.pass_chain_forward:
|
||||
self.c_constric = MultiConvBlock(32, 32, 16, 3, 3)
|
||||
self.c_conjoin = ConvBnLelu(32, 16, kernel_size=1, bn=False)
|
||||
|
||||
# x is the input fed to the transform blocks.
|
||||
# m is the output of the multiplexer which will be used to select from those transform blocks.
|
||||
# chain is a chain of shared processing outputs used by the individual transforms.
|
||||
|
@ -30,11 +87,21 @@ class Switch(nn.Module):
|
|||
xformed = [o[0] for o in pcf]
|
||||
atts = [o[1] for o in pcf]
|
||||
else:
|
||||
# These adjustments were determined statistically from numeric_stability.py and should start this context
|
||||
# out in a normal distribution.
|
||||
context = (chain[-1] - 6) / 9.4
|
||||
context = F.pixel_shuffle(context, 4)
|
||||
context = self.c_constric(context)
|
||||
|
||||
context = F.interpolate(context, size=x.shape[2:], mode='nearest')
|
||||
context = torch.cat([x, context], dim=1)
|
||||
context = self.c_conjoin(context)
|
||||
|
||||
if self.add_noise:
|
||||
rand_feature = torch.randn_like(x)
|
||||
xformed = [t.forward(x, rand_feature) for t in self.transforms]
|
||||
xformed = [t.forward(context, rand_feature) for t in self.transforms]
|
||||
else:
|
||||
xformed = [t.forward(x) for t in self.transforms]
|
||||
xformed = [t.forward(context) for t in self.transforms]
|
||||
|
||||
# Interpolate the multiplexer across the entire shape of the image.
|
||||
m = F.interpolate(m, size=x.shape[2:], mode='nearest')
|
||||
|
@ -65,13 +132,13 @@ class Processor(nn.Module):
|
|||
|
||||
# Downsample block used for bottleneck.
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2),
|
||||
nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2, bias=False),
|
||||
nn.BatchNorm2d(self.output_filter_count),
|
||||
)
|
||||
# Bottleneck block outputs the requested filter sizex4, but we only want x2.
|
||||
self.initial = Bottleneck(base_filters, base_filters // 2, stride=2 if reduce else 1, downsample=downsample)
|
||||
self.initial = FixupBottleneck(base_filters, base_filters // 2, stride=2 if reduce else 1, downsample=downsample)
|
||||
|
||||
self.res_blocks = nn.ModuleList([BasicBlock(self.output_filter_count, self.output_filter_count) for _ in range(processing_depth)])
|
||||
self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.initial(x)
|
||||
|
@ -90,15 +157,7 @@ class Constrictor(nn.Module):
|
|||
gap_div_4 = int(gap / 4)
|
||||
self.cbl1 = ConvBnLelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True)
|
||||
self.cbl2 = ConvBnLelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True)
|
||||
self.cbl3 = nn.Conv2d(filters - (gap_div_4 * 3), output_filters, kernel_size=1)
|
||||
|
||||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
self.cbl3 = ConvBnLelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, lelu=False, bn=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.cbl1(x)
|
||||
|
@ -150,23 +209,19 @@ class NestedSwitchComputer(nn.Module):
|
|||
self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
|
||||
self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
|
||||
|
||||
# Init the parameters in the trunk.
|
||||
# Init the parameters in the trunk. Uses the fixup algorithm for residual conv initialization.
|
||||
self.num_layers = nesting_depth + nesting_depth * num_switch_processing_layers
|
||||
for m in self.processing_trunk.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if isinstance(m, FixupBottleneck):
|
||||
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
|
||||
nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
|
||||
nn.init.constant_(m.conv3.weight, 0)
|
||||
if m.downsample is not None:
|
||||
nn.init.normal_(m.downsample[0].weight, mean=0, std=np.sqrt(2 / (m.downsample[0].weight.shape[0] * np.prod(m.downsample[0].weight.shape[2:]))))
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.kaiming_normal_(self.anneal.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
for m in self.processing_trunk.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu")
|
||||
|
||||
def forward(self, x):
|
||||
trunk = []
|
||||
|
@ -175,6 +230,7 @@ class NestedSwitchComputer(nn.Module):
|
|||
trunk_input = m.forward(trunk_input)
|
||||
trunk.append(trunk_input)
|
||||
|
||||
self.trunk = (trunk[-1] - 6) / 9.4
|
||||
x, att = self.switch.forward(x, trunk)
|
||||
return self.anneal(x), att
|
||||
|
||||
|
@ -187,8 +243,8 @@ class NestedSwitchedGenerator(nn.Module):
|
|||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||
heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
|
||||
super(NestedSwitchedGenerator, self).__init__()
|
||||
self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, bn=False)
|
||||
self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, bn=False)
|
||||
self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, lelu=False, bn=False)
|
||||
self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, lelu=False, bn=False)
|
||||
|
||||
switches = []
|
||||
for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
|
||||
|
@ -196,8 +252,6 @@ class NestedSwitchedGenerator(nn.Module):
|
|||
nesting_depth=sw_reduce, transforms_at_leaf=trans_count, trans_kernel_size=kernel, trans_num_layers=layers,
|
||||
trans_scale_init=.2/len(switch_reductions), initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
|
||||
self.switches = nn.ModuleList(switches)
|
||||
nn.init.kaiming_normal_(self.initial_conv.conv.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
nn.init.kaiming_normal_(self.final_conv.conv.weight, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
self.transformation_counts = trans_counts
|
||||
self.init_temperature = initial_temp
|
||||
|
@ -208,7 +262,6 @@ class NestedSwitchedGenerator(nn.Module):
|
|||
self.upsample_factor = upsample_factor
|
||||
|
||||
def forward(self, x):
|
||||
k = x
|
||||
# This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
|
||||
# is called for, then repair.
|
||||
if self.upsample_factor > 1:
|
||||
|
|
|
@ -106,10 +106,7 @@ class FixupResNet(nn.Module):
|
|||
nn.init.constant_(m.conv2.weight, 0)
|
||||
if m.downsample is not None:
|
||||
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
|
||||
'''
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.constant_(m.weight, 0)
|
||||
nn.init.constant_(m.bias, 0)'''
|
||||
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, conv_type=conv3x3):
|
||||
defilter = None
|
||||
|
|
|
@ -8,6 +8,7 @@ from models.archs.arch_util import initialize_weights
|
|||
from switched_conv_util import save_attention_to_image
|
||||
|
||||
|
||||
''' Convenience class with Conv->BN->LeakyRelu. Includes Kaiming weight initialization. '''
|
||||
class ConvBnLelu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True):
|
||||
super(ConvBnLelu, self).__init__()
|
||||
|
@ -23,6 +24,14 @@ class ConvBnLelu(nn.Module):
|
|||
else:
|
||||
self.lelu = None
|
||||
|
||||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out', nonlinearity='leaky_relu' if self.lelu else 'linear')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
if self.bn:
|
||||
|
@ -44,14 +53,6 @@ class MultiConvBlock(nn.Module):
|
|||
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
|
||||
self.bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x, noise=None):
|
||||
if noise is not None:
|
||||
noise = noise * self.noise_scale
|
||||
|
|
Loading…
Reference in New Issue
Block a user