Add SpineNet & integrate with SRG

The new version of SRG uses SpineNet as the switch backbone.
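Roughly: each switch's selection logits now come from a SpineNetMultiplexer (a SpineNet-49 backbone followed by three reducing convolutions) instead of the old ConvBasisMultiplexer. A minimal usage sketch, mirroring the smoke test updated in this commit (shapes and hyperparameters are illustrative, not prescriptive):

import torch
import models.archs.SwitchedResidualGenerator_arch as srg

# One switch with 8 transforms; the SpineNet multiplexer computes the switch attention.
gen = srg.ConfigurableSwitchedResidualGenerator3(
    trans_counts=[8], trans_kernel_sizes=[3], trans_layers=[3],
    transformation_filters=64, initial_temp=10)
out, = gen(torch.randn(1, 3, 128, 128))  # forward returns a 1-tuple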
This commit is contained in:
James Betker 2020-07-03 12:07:31 -06:00
parent 3ed7a2b9ab
commit 703dec4472
5 changed files with 444 additions and 83 deletions

View File

@@ -4,79 +4,11 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
import torch.nn.functional as F
import functools
from collections import OrderedDict
from models.archs.arch_util import initialize_weights
from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
from models.archs.spinenet_arch import SpineNet
from switched_conv_util import save_attention_to_image
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnRelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True):
super(ConvBnRelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if bn:
self.bn = nn.BatchNorm2d(filters_out)
else:
self.bn = None
if relu:
self.relu = nn.ReLU()
else:
self.relu = None
# Init params.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv(x)
if self.bn:
x = self.bn(x)
if self.relu:
return self.relu(x)
else:
return x
''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True):
super(ConvBnLelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if bn:
self.bn = nn.BatchNorm2d(filters_out)
else:
self.bn = None
if lelu:
self.lelu = nn.LeakyReLU(negative_slope=.1)
else:
self.lelu = None
# Init params.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out', nonlinearity='leaky_relu' if self.lelu else 'linear')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv(x)
if self.bn:
x = self.bn(x)
if self.lelu:
return self.lelu(x)
else:
return x
class MultiConvBlock(nn.Module):
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
@@ -214,7 +146,7 @@ class ConfigurableSwitchComputer(nn.Module):
m = self.multiplexer(identity)
# Interpolate the multiplexer across the entire shape of the image.
m = F.interpolate(m, size=x.shape[2:], mode='nearest')
m = F.interpolate(m, size=xformed[0].shape[2:], mode='nearest')
outputs, attention = self.switch(xformed, m, True)
outputs = identity + outputs * self.switch_scale
@@ -252,6 +184,22 @@ class ConvBasisMultiplexer(nn.Module):
return x
class SpineNetMultiplexer(nn.Module):
def __init__(self, input_channels, transform_count):
super(SpineNetMultiplexer, self).__init__()
self.backbone = SpineNet('49', in_channels=input_channels)
self.rdc1 = ConvBnRelu(256, 128, kernel_size=3, bias=False)
self.rdc2 = ConvBnRelu(128, 64, kernel_size=3, bias=False)
self.rdc3 = ConvBnRelu(64, transform_count, bias=False, bn=False, relu=False)
def forward(self, x):
spine = self.backbone(x)
feat = self.rdc1(spine[0])
feat = self.rdc2(feat)
feat = self.rdc3(feat)
return feat
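# Rough shape sketch (illustrative, not part of this diff): with the '49' profile the first
# SpineNet endpoint carries 256 channels, which rdc1/rdc2/rdc3 reduce to 128, 64 and finally
# transform_count channels, e.g. SpineNetMultiplexer(3, 8) maps (N, 3, H, W) images to an
# (N, 8, H', W') selection map whose spatial size is set by the backbone's downsampling.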
class ConvBasisMultiplexerReducer(nn.Module):
def __init__(self, input_channels, base_filters, growth, reductions, processing_depth):
super(ConvBasisMultiplexerReducer, self).__init__()
@@ -415,6 +363,97 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
if step % 50 == 0:
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))]
def get_debug_values(self, step):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {"switch_temperature": temp}
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
class Interpolate(nn.Module):
def __init__(self, factor):
super(Interpolate, self).__init__()
self.factor = factor
def forward(self, x):
return F.interpolate(x, scale_factor=self.factor)
class ConfigurableSwitchedResidualGenerator3(nn.Module):
def __init__(self, trans_counts,
trans_kernel_sizes,
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000,
heightened_temp_min=1,
heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
add_scalable_noise_to_transforms=False):
super(ConfigurableSwitchedResidualGenerator3, self).__init__()
switches = []
for trans_count, kernel, layers in zip(trans_counts, trans_kernel_sizes, trans_layers):
multiplx_fn = functools.partial(SpineNetMultiplexer, 3)
switches.append(ConfigurableSwitchComputer(base_filters=3, multiplexer_net=multiplx_fn,
pre_transform_block=functools.partial(nn.Sequential,
ConvBnLelu(3, transformation_filters, kernel_size=1, stride=4, bn=False, lelu=False, bias=False),
ResidualDenseBlock_5C(
transformation_filters),
ResidualDenseBlock_5C(
transformation_filters)),
transform_block=functools.partial(nn.Sequential,
ResidualDenseBlock_5C(transformation_filters),
Interpolate(4),
ConvBnLelu(transformation_filters, transformation_filters // 2, kernel_size=3, bias=False, bn=False),
ConvBnLelu(transformation_filters // 2, 3, kernel_size=1, bias=False, bn=False, lelu=False)),
transform_count=trans_count, init_temp=initial_temp,
enable_negative_transforms=enable_negative_transforms,
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms,
init_scalar=.01))
self.switches = nn.ModuleList(switches)
self.transformation_counts = trans_counts
self.init_temperature = initial_temp
self.final_temperature_step = final_temperature_step
self.heightened_temp_min = heightened_temp_min
self.heightened_final_step = heightened_final_step
self.attentions = None
self.upsample_factor = upsample_factor
def forward(self, x):
if self.upsample_factor > 1:
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
self.attentions = []
for i, sw in enumerate(self.switches):
x, att = sw.forward(x, True)
self.attentions.append(att)
return x,
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
temp = max(1, int(
self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
# Once the temperature passes 1, it enters an inverted curve to match the linear curve from above.
# Without this, the attention specificity "spikes" incredibly fast in the last few iterations.
h_steps_total = self.heightened_final_step - self.final_temperature_step
h_steps_current = min(step - self.final_temperature_step, h_steps_total)
# The "gap" will represent the steps that need to be traveled as a linear function.
h_gap = 1 / self.heightened_temp_min
temp = h_gap * h_steps_current / h_steps_total
# Invert temperature to represent reality on this side of the curve
temp = 1 / temp
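# Worked example (illustrative numbers): with heightened_temp_min=.1, h_gap = 10; halfway
# through the heightened phase temp = 1 / (10 * .5) = .2, and at heightened_final_step
# temp = 1 / 10 = .1, so the temperature glides from 1 down to heightened_temp_min rather
# than snapping there.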
self.set_temperature(temp)
if step % 50 == 0:
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step,
"a%i" % (i + 1,)) for i in range(len(self.switches))]
def get_debug_values(self, step):
temp = self.switches[0].switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]

View File

@@ -0,0 +1,319 @@
# Taken and modified from https://github.com/lucifer443/SpineNet-Pytorch/blob/master/mmdet/models/backbones/spinenet.py
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal
from torchvision.models.resnet import BasicBlock, Bottleneck
from torch.nn.modules.batchnorm import _BatchNorm
from models.archs.arch_util import ConvBnRelu
def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def kaiming_init(module,
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
distribution='normal'):
assert distribution in ['uniform', 'normal']
if distribution == 'uniform':
nn.init.kaiming_uniform_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
else:
nn.init.kaiming_normal_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
FILTER_SIZE_MAP = {
1: 32,
2: 64,
3: 128,
4: 256,
5: 256,
6: 256,
7: 256,
}
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride,
dilation=dilation,
downsample=downsample))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=1,
dilation=dilation))
return nn.Sequential(*layers)
# The fixed SpineNet architecture discovered by NAS.
# Each element represents a specification of a building block:
# (block_level, block_fn, (input_offset0, input_offset1), is_output).
SPINENET_BLOCK_SPECS = [
(2, Bottleneck, (None, None), False), # init block
(2, Bottleneck, (None, None), False), # init block
(2, Bottleneck, (0, 1), False),
(4, BasicBlock, (0, 1), False),
(3, Bottleneck, (2, 3), False),
(4, Bottleneck, (2, 4), False),
(6, BasicBlock, (3, 5), False),
(4, Bottleneck, (3, 5), False),
(5, BasicBlock, (6, 7), False),
(7, BasicBlock, (6, 8), False),
(5, Bottleneck, (8, 9), False),
(5, Bottleneck, (8, 10), False),
(4, Bottleneck, (5, 10), True),
(3, Bottleneck, (4, 10), True),
(5, Bottleneck, (7, 12), True),
(7, Bottleneck, (5, 14), True),
(6, Bottleneck, (12, 14), True),
]
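# Reading example (illustrative): the spec (4, BasicBlock, (0, 1), False) builds a level-4
# BasicBlock whose parents are blocks 0 and 1 (the two level-2 init blocks); Merge resamples
# each parent by 2**(2 - 4) = 1/4 before summing, and is_output=False means the block only
# feeds later blocks rather than an endpoint conv.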
SCALING_MAP = {
'49S': {
'endpoints_num_filters': 128,
'filter_size_scale': 0.65,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'49': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'96': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 2,
},
'143': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 1.0,
'block_repeats': 3,
},
'190': {
'endpoints_num_filters': 512,
'filter_size_scale': 1.3,
'resample_alpha': 1.0,
'block_repeats': 4,
},
}
class BlockSpec(object):
"""A container class that specifies the block configuration for SpineNet."""
def __init__(self, level, block_fn, input_offsets, is_output):
self.level = level
self.block_fn = block_fn
self.input_offsets = input_offsets
self.is_output = is_output
def build_block_specs(block_specs=None):
"""Builds the list of BlockSpec objects for SpineNet."""
if not block_specs:
block_specs = SPINENET_BLOCK_SPECS
return [BlockSpec(*b) for b in block_specs]
class Resample(nn.Module):
def __init__(self, in_channels, out_channels, scale, block_type, alpha=1.0):
super(Resample, self).__init__()
self.scale = scale
new_in_channels = int(in_channels * alpha)
if block_type == Bottleneck:
in_channels *= 4
self.squeeze_conv = ConvBnRelu(in_channels, new_in_channels, kernel_size=1)
if scale < 1:
self.downsample_conv = ConvBnRelu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
self.expand_conv = ConvBnRelu(new_in_channels, out_channels, kernel_size=1, relu=False)
def _resize(self, x):
if self.scale == 1:
return x
elif self.scale > 1:
return F.interpolate(x, scale_factor=self.scale, mode='nearest')
else:
x = self.downsample_conv(x)
if self.scale < 0.5:
new_kernel_size = 3 if self.scale >= 0.25 else 5
x = F.max_pool2d(x, kernel_size=new_kernel_size, stride=int(0.5/self.scale), padding=new_kernel_size//2)
return x
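# Downsampling example (illustrative): for scale=0.25 the strided conv above halves the
# resolution and the max_pool (kernel 3, stride int(0.5 / 0.25) = 2) halves it again for the
# full 4x reduction; scales below 0.25 switch to a kernel-5 pool with a larger stride.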
def forward(self, inputs):
feat = self.squeeze_conv(inputs)
feat = self._resize(feat)
feat = self.expand_conv(feat)
return feat
class Merge(nn.Module):
"""Merge two input tensors"""
def __init__(self, block_spec, alpha, filter_size_scale):
super(Merge, self).__init__()
out_channels = int(FILTER_SIZE_MAP[block_spec.level] * filter_size_scale)
if block_spec.block_fn == Bottleneck:
out_channels *= 4
self.block = block_spec.block_fn
self.resample_ops = nn.ModuleList()
for spec_idx in block_spec.input_offsets:
spec = BlockSpec(*SPINENET_BLOCK_SPECS[spec_idx])
in_channels = int(FILTER_SIZE_MAP[spec.level] * filter_size_scale)
scale = 2**(spec.level - block_spec.level)
self.resample_ops.append(
Resample(in_channels, out_channels, scale, spec.block_fn, alpha)
)
def forward(self, inputs):
assert len(inputs) == len(self.resample_ops)
parent0_feat = self.resample_ops[0](inputs[0])
parent1_feat = self.resample_ops[1](inputs[1])
target_feat = parent0_feat + parent1_feat
return target_feat
class SpineNet(nn.Module):
"""Class to build SpineNet backbone"""
def __init__(self,
arch,
in_channels=3,
output_level=[3, 4, 5, 6, 7],
zero_init_residual=True):
super(SpineNet, self).__init__()
self._block_specs = build_block_specs()[2:]
self._endpoints_num_filters = SCALING_MAP[arch]['endpoints_num_filters']
self._resample_alpha = SCALING_MAP[arch]['resample_alpha']
self._block_repeats = SCALING_MAP[arch]['block_repeats']
self._filter_size_scale = SCALING_MAP[arch]['filter_size_scale']
self._init_block_fn = Bottleneck
self._num_init_blocks = 2
self.zero_init_residual = zero_init_residual
assert min(output_level) > 2 and max(output_level) < 8, "Output level out of range"
self.output_level = output_level
self._make_stem_layer(in_channels)
self._make_scale_permuted_network()
self._make_endpoints()
def _make_stem_layer(self, in_channels):
"""Build the stem network."""
# Build the first conv and maxpooling layers.
self.conv1 = ConvBnRelu(
in_channels,
64,
kernel_size=7,
stride=2) # The original paper uses stride=2 followed by a maxpool; the maxpool is omitted here.
# Build the initial level 2 blocks.
self.init_block1 = make_res_layer(
self._init_block_fn,
64,
int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
self._block_repeats)
self.init_block2 = make_res_layer(
self._init_block_fn,
int(FILTER_SIZE_MAP[2] * self._filter_size_scale) * 4,
int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
self._block_repeats)
def _make_endpoints(self):
self.endpoint_convs = nn.ModuleDict()
for block_spec in self._block_specs:
if block_spec.is_output:
in_channels = int(FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4
self.endpoint_convs[str(block_spec.level)] = ConvBnRelu(in_channels,
self._endpoints_num_filters,
kernel_size=1,
relu=False)
def _make_scale_permuted_network(self):
self.merge_ops = nn.ModuleList()
self.scale_permuted_blocks = nn.ModuleList()
for spec in self._block_specs:
self.merge_ops.append(
Merge(spec, self._resample_alpha, self._filter_size_scale)
)
channels = int(FILTER_SIZE_MAP[spec.level] * self._filter_size_scale)
in_channels = channels * 4 if spec.block_fn == Bottleneck else channels
self.scale_permuted_blocks.append(
make_res_layer(spec.block_fn,
in_channels,
channels,
self._block_repeats)
)
def init_weights(self, pretrained=None):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.bn3, 0)  # torchvision's Bottleneck exposes its final norm layer as bn3 (not norm3)
elif isinstance(m, BasicBlock):
constant_init(m.bn2, 0)  # torchvision's BasicBlock exposes its final norm layer as bn2
def forward(self, input):
feat = self.conv1(input)
feat1 = self.init_block1(feat)
feat2 = self.init_block2(feat1)
block_feats = [feat1, feat2]
output_feat = {}
num_outgoing_connections = [0, 0]
for i, spec in enumerate(self._block_specs):
target_feat = self.merge_ops[i]([block_feats[feat_idx] for feat_idx in spec.input_offsets])
# Connect intermediate blocks with outdegree 0 to the output block.
if spec.is_output:
for j, (j_feat, j_connections) in enumerate(
zip(block_feats, num_outgoing_connections)):
if j_connections == 0 and j_feat.shape == target_feat.shape:
target_feat += j_feat
num_outgoing_connections[j] += 1
target_feat = F.relu(target_feat, inplace=True)
target_feat = self.scale_permuted_blocks[i](target_feat)
block_feats.append(target_feat)
num_outgoing_connections.append(0)
for feat_idx in spec.input_offsets:
num_outgoing_connections[feat_idx] += 1
if spec.is_output:
output_feat[spec.level] = target_feat
return [self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level]
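# Usage sketch (illustrative, not part of the committed file): the '49' profile produces one
# 256-channel endpoint per requested output level, e.g.
#   backbone = SpineNet('49', in_channels=3)
#   feats = backbone(torch.randn(1, 3, 64, 64))  # list of 5 feature maps for levels 3..7
# The SpineNetMultiplexer in SwitchedResidualGenerator_arch consumes feats[0], the
# highest-resolution endpoint.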

View File

@@ -66,6 +66,13 @@ def define_G(opt, net_key='network_G'):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "ConfigurableSwitchedResidualGenerator3":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator3(trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "NestedSwitchGenerator":
netG = ng.NestedSwitchedGenerator(switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
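For reference, the new branch pulls its hyperparameters from the network_G options block. A hypothetical minimal configuration is sketched below; the values and the which_model_G key name are assumptions, only the remaining keys are taken verbatim from the branch above, and train_div2k_srg3.yml (the new default option file for the training script) is where they would live:

opt_net = {
    'which_model_G': 'ConfigurableSwitchedResidualGenerator3',  # assumed key name for model selection
    'trans_counts': [8], 'trans_kernel_sizes': [3], 'trans_layers': [3],
    'transformation_filters': 64,
    'temperature': 10, 'temperature_final_step': 50000,
    'heightened_temp_min': 1, 'heightened_final_step': 50000,
    'add_noise': False,
}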

View File

@@ -33,7 +33,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_feat_resgen2_lr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_srg3.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@@ -162,7 +162,7 @@ def main():
current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers
else:
current_step = -1
current_step = 0
start_epoch = 0
#### training

View File

@@ -4,7 +4,7 @@ import models.archs.SwitchedResidualGenerator_arch as srg
import models.archs.NestedSwitchGenerator as nsg
import functools
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax, srg.Interpolate]
def install_forward_trace_hooks(module, id="base"):
if type(module) in blacklisted_modules:
return
@@ -96,15 +96,11 @@ if __name__ == "__main__":
torch.randn(1, 3, 64, 64),
device='cuda')
'''
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
switch_filters=[16,16,16,16,16],
switch_growths=[32,32,32,32,32],
switch_reductions=[1,1,1,1,1],
switch_processing_layers=[5,5,5,5,5],
trans_counts=[8,8,8,8,8],
trans_kernel_sizes=[3,3,3,3,3],
trans_layers=[3,3,3,3,3],
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
trans_counts=[8],
trans_kernel_sizes=[3],
trans_layers=[3],
transformation_filters=64,
initial_temp=10),
torch.randn(1, 3, 64, 64),
torch.randn(1, 3, 128, 128),
device='cuda')