forked from mrq/DL-Art-School
NSG r7
Converts the switching trunk to a VGG-style network to make it more comparable to SRG architectures.
This commit is contained in:
parent
87f1e9c56f
commit
604763be68
|
@ -1,64 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, MultiConvBlock, initialize_weights
|
from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, ConvBnRelu, MultiConvBlock, initialize_weights
|
||||||
from switched_conv import BareConvSwitch, compute_attention_specificity
|
from switched_conv import BareConvSwitch, compute_attention_specificity
|
||||||
from switched_conv_util import save_attention_to_image
|
from switched_conv_util import save_attention_to_image
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
def conv3x3(in_planes, out_planes, stride=1):
|
|
||||||
"""3x3 convolution with padding"""
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
|
||||||
padding=1, bias=False)
|
|
||||||
|
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
|
||||||
"""1x1 convolution"""
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
|
||||||
|
|
||||||
|
|
||||||
# Taken from Fixup resnet implementation https://github.com/hongyi-zhang/Fixup/blob/master/imagenet/models/fixup_resnet_imagenet.py
|
|
||||||
class FixupBottleneck(nn.Module):
|
|
||||||
expansion = 4
|
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
|
||||||
super(FixupBottleneck, self).__init__()
|
|
||||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
|
||||||
self.bias1a = nn.Parameter(torch.zeros(1))
|
|
||||||
self.conv1 = conv1x1(inplanes, planes)
|
|
||||||
self.bias1b = nn.Parameter(torch.zeros(1))
|
|
||||||
self.bias2a = nn.Parameter(torch.zeros(1))
|
|
||||||
self.conv2 = conv3x3(planes, planes, stride)
|
|
||||||
self.bias2b = nn.Parameter(torch.zeros(1))
|
|
||||||
self.bias3a = nn.Parameter(torch.zeros(1))
|
|
||||||
self.conv3 = conv1x1(planes, planes * self.expansion)
|
|
||||||
self.scale = nn.Parameter(torch.ones(1))
|
|
||||||
self.bias3b = nn.Parameter(torch.zeros(1))
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
identity = x
|
|
||||||
|
|
||||||
out = self.conv1(x + self.bias1a)
|
|
||||||
out = self.relu(out + self.bias1b)
|
|
||||||
|
|
||||||
out = self.conv2(out + self.bias2a)
|
|
||||||
out = self.relu(out + self.bias2b)
|
|
||||||
|
|
||||||
out = self.conv3(out + self.bias3a)
|
|
||||||
out = out * self.scale + self.bias3b
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
identity = self.downsample(x + self.bias1a)
|
|
||||||
|
|
||||||
out += identity
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Switch(nn.Module):
|
class Switch(nn.Module):
|
||||||
|
@ -110,30 +57,21 @@ class Switch(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
# Convolutional image processing block that optionally reduces image size by a factor of 2 using stride and performs a
|
# Convolutional image processing block that optionally reduces image size by a factor of 2 using stride and performs a
|
||||||
# series of residual-block-like processing operations on it.
|
# series of conv blocks on it
|
||||||
class Processor(nn.Module):
|
class Processor(nn.Module):
|
||||||
def __init__(self, base_filters, processing_depth, reduce=False):
|
def __init__(self, base_filters, processing_depth, reduce=False):
|
||||||
super(Processor, self).__init__()
|
super(Processor, self).__init__()
|
||||||
self.output_filter_count = base_filters * (2 if reduce else 1)
|
self.output_filter_count = base_filters * (2 if reduce else 1)
|
||||||
|
self.pre = ConvBnRelu(base_filters, base_filters, kernel_size=3, bias=True)
|
||||||
# Downsample block used for bottleneck.
|
self.initial = ConvBnRelu(base_filters, self.output_filter_count, kernel_size=1, stride=2 if reduce else 1, bias=False)
|
||||||
if reduce:
|
self.blocks = nn.Sequential(OrderedDict(
|
||||||
downsample = nn.Sequential(
|
[(str(i), ConvBnRelu(self.output_filter_count, self.output_filter_count, kernel_size=3, bias=False)) for i in range(processing_depth)]))
|
||||||
nn.Conv2d(base_filters, self.output_filter_count, kernel_size=1, stride=2, bias=False),
|
|
||||||
nn.BatchNorm2d(self.output_filter_count),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
downsample = None
|
|
||||||
# Bottleneck block outputs the requested filter sizex4, but we only want x2.
|
|
||||||
self.initial = FixupBottleneck(base_filters, self.output_filter_count // 4, stride=2 if reduce else 1, downsample=downsample)
|
|
||||||
self.res_blocks = nn.ModuleList([FixupBottleneck(self.output_filter_count, self.output_filter_count // 4) for _ in range(processing_depth)])
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = (self.initial(x) - .4) / .6
|
x = self.pre(x)
|
||||||
for b in self.res_blocks:
|
x = self.initial(x)
|
||||||
r = (b(x) - .4) / .6
|
x = self.blocks(x)
|
||||||
x = r + x
|
return (x - .39) / .58
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# Convolutional image processing block that constricts an input image with a large number of filters to a small number
|
# Convolutional image processing block that constricts an input image with a large number of filters to a small number
|
||||||
|
@ -144,15 +82,15 @@ class Constrictor(nn.Module):
|
||||||
assert(filters > output_filters)
|
assert(filters > output_filters)
|
||||||
gap = filters - output_filters
|
gap = filters - output_filters
|
||||||
gap_div_4 = int(gap / 4)
|
gap_div_4 = int(gap / 4)
|
||||||
self.cbl1 = ConvBnLelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True)
|
self.cbl1 = ConvBnRelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True, bias=True)
|
||||||
self.cbl2 = ConvBnLelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True)
|
self.cbl2 = ConvBnRelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True, bias=False)
|
||||||
self.cbl3 = ConvBnLelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, lelu=False, bn=False)
|
self.cbl3 = ConvBnRelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, relu=False, bn=False, bias=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.cbl1(x)
|
x = self.cbl1(x)
|
||||||
x = self.cbl2(x)
|
x = self.cbl2(x)
|
||||||
x = self.cbl3(x) / 4
|
x = self.cbl3(x)
|
||||||
return x
|
return x / 2.67
|
||||||
|
|
||||||
|
|
||||||
class RecursiveSwitchedTransform(nn.Module):
|
class RecursiveSwitchedTransform(nn.Module):
|
||||||
|
@ -200,19 +138,6 @@ class NestedSwitchComputer(nn.Module):
|
||||||
self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
|
self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
|
||||||
self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
|
self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
|
||||||
|
|
||||||
# Init the parameters in the trunk. Uses the fixup algorithm for residual conv initialization.
|
|
||||||
self.num_layers = nesting_depth + nesting_depth * num_switch_processing_layers
|
|
||||||
for m in self.processing_trunk.modules():
|
|
||||||
if isinstance(m, FixupBottleneck):
|
|
||||||
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25))
|
|
||||||
nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25))
|
|
||||||
nn.init.constant_(m.conv3.weight, 0)
|
|
||||||
if m.downsample is not None:
|
|
||||||
nn.init.normal_(m.downsample[0].weight, mean=0, std=np.sqrt(2 / (m.downsample[0].weight.shape[0] * np.prod(m.downsample[0].weight.shape[2:]))))
|
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
|
||||||
nn.init.constant_(m.weight, 1)
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
feed_forward = x
|
feed_forward = x
|
||||||
trunk = []
|
trunk = []
|
||||||
|
|
|
@ -7,14 +7,49 @@ from collections import OrderedDict
|
||||||
from models.archs.arch_util import initialize_weights
|
from models.archs.arch_util import initialize_weights
|
||||||
from switched_conv_util import save_attention_to_image
|
from switched_conv_util import save_attention_to_image
|
||||||
|
|
||||||
|
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
|
||||||
|
kernel sizes. '''
|
||||||
|
class ConvBnRelu(nn.Module):
|
||||||
|
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True):
|
||||||
|
super(ConvBnRelu, self).__init__()
|
||||||
|
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||||
|
assert kernel_size in padding_map.keys()
|
||||||
|
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||||
|
if bn:
|
||||||
|
self.bn = nn.BatchNorm2d(filters_out)
|
||||||
|
else:
|
||||||
|
self.bn = None
|
||||||
|
if relu:
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
else:
|
||||||
|
self.relu = None
|
||||||
|
|
||||||
''' Convenience class with Conv->BN->LeakyRelu. Includes Kaiming weight initialization. '''
|
# Init params.
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
if self.bn:
|
||||||
|
x = self.bn(x)
|
||||||
|
if self.relu:
|
||||||
|
return self.relu(x)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
|
||||||
|
kernel sizes. '''
|
||||||
class ConvBnLelu(nn.Module):
|
class ConvBnLelu(nn.Module):
|
||||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True):
|
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True):
|
||||||
super(ConvBnLelu, self).__init__()
|
super(ConvBnLelu, self).__init__()
|
||||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||||
assert kernel_size in padding_map.keys()
|
assert kernel_size in padding_map.keys()
|
||||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size])
|
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||||
if bn:
|
if bn:
|
||||||
self.bn = nn.BatchNorm2d(filters_out)
|
self.bn = nn.BatchNorm2d(filters_out)
|
||||||
else:
|
else:
|
||||||
|
@ -47,9 +82,9 @@ class MultiConvBlock(nn.Module):
|
||||||
assert depth >= 2
|
assert depth >= 2
|
||||||
super(MultiConvBlock, self).__init__()
|
super(MultiConvBlock, self).__init__()
|
||||||
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
|
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
|
||||||
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn)] +
|
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False)] +
|
||||||
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn) for i in range(depth-2)] +
|
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False) for i in range(depth-2)] +
|
||||||
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)])
|
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False)])
|
||||||
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
|
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
|
||||||
self.bias = nn.Parameter(torch.zeros(1))
|
self.bias = nn.Parameter(torch.zeros(1))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user