Model arch cleanup

James Betker 2020-09-27 11:18:45 -06:00
parent 7dff802144
commit 4d29b7729e
5 changed files with 0 additions and 648 deletions

@@ -1,80 +0,0 @@
import math
import torch
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def conv5x5(in_planes, out_planes, stride=1):
"""5x5 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
padding=2, bias=False)
def conv7x7(in_planes, out_planes, stride=1):
"""7x7 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride,
padding=3, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class SequenceDistributed(nn.Module):
def __init__(self, module, batch_first=False):
super(SequenceDistributed, self).__init__()
self.module = module
self.batch_first = batch_first
    def forward(self, x):
        if x.dim() <= 2:
            return self.module(x)
        # Squash samples and timesteps into a single axis. Collapse only the two leading dims
        # (rather than everything but the last) so conv modules with (C, H, W) features still work.
        x_reshape = x.contiguous().view(-1, *x.size()[2:])  # (samples * timesteps, features...)
        y = self.module(x_reshape)
        # Restore the leading dims.
        if self.batch_first:
            y = y.view(x.size(0), -1, *y.size()[1:])  # (samples, timesteps, features...)
        else:
            y = y.view(-1, x.size(1), *y.size()[1:])  # (timesteps, samples, features...)
        return y
# Input into this block is of shape (batch, sequence, filters, width, height).
# Output is (batch, attention_hidden_size, width, height).
class ConvAttentionBlock(nn.Module):
def __init__(self, planes, attention_hidden_size=8, query_conv=conv1x1, key_conv=conv1x1, value_conv=conv1x1):
super(ConvAttentionBlock, self).__init__()
self.query_conv_dist = SequenceDistributed(query_conv(planes, attention_hidden_size))
self.key_conv_dist = SequenceDistributed(key_conv(planes, attention_hidden_size))
        self.value_conv_dist = SequenceDistributed(value_conv(planes, attention_hidden_size))
self.hidden_size = attention_hidden_size
def forward(self, x):
# All values come out of this with the shape (batch, sequence, hidden, width, height)
query = self.query_conv_dist(x)
key = self.key_conv_dist(x)
value = self.value_conv_dist(x)
# Permute to (batch, width, height, sequence, hidden)
query = query.permute(0, 3, 4, 1, 2)
key = key.permute(0, 3, 4, 1, 2)
value = value.permute(0, 3, 4, 1, 2)
# Perform attention operation.
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.hidden_size)
scores = torch.softmax(scores, dim=-1)
result = torch.matmul(scores, value)
# Collapse out the sequence dim.
result = torch.sum(result, dim=-2)
# Permute back to (batch, hidden, width, height)
result = result.permute(0, 3, 1, 2)
return result
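A minimal usage sketch for ConvAttentionBlock (an editorial sketch, assuming the batch-first layout described above):

import torch

block = ConvAttentionBlock(planes=16, attention_hidden_size=8)
x = torch.randn(2, 5, 16, 32, 32)  # (batch, sequence, planes, width, height)
out = block(x)
print(out.shape)  # torch.Size([2, 8, 32, 32])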

@@ -1,134 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class FixupBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(FixupBasicBlock, self).__init__()
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.bias1a = nn.Parameter(torch.zeros(1))
self.conv1 = conv3x3(inplanes, planes, stride)
self.bias1b = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.bias2a = nn.Parameter(torch.zeros(1))
self.conv2 = conv3x3(planes, planes)
self.scale = nn.Parameter(torch.ones(1))
self.bias2b = nn.Parameter(torch.zeros(1))
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x + self.bias1a)
out = self.lrelu(out + self.bias1b)
out = self.conv2(out + self.bias2a)
out = out * self.scale + self.bias2b
if self.downsample is not None:
identity = self.downsample(x + self.bias1a)
out += identity
out = self.lrelu(out)
return out
class FixupResNet(nn.Module):
def __init__(self, block, num_filters, layers, num_classes=1000):
super(FixupResNet, self).__init__()
self.num_layers = sum(layers)
self.bias1 = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.pixel_shuffle = nn.PixelShuffle(2)
# 4 input channels, including the noise.
self.conv1 = nn.Conv2d(4, num_filters, kernel_size=7, stride=2, padding=3,
bias=False)
self.inplanes = num_filters
self.down_layer1 = self._make_layer(block, num_filters, layers[0])
self.down_layer2 = self._make_layer(block, num_filters, layers[1], stride=2)
self.down_layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2)
self.down_layer4 = self._make_layer(block, num_filters * 16, layers[3], stride=2)
self.inplanes = num_filters * 4
self.up_layer1 = self._make_layer(block, num_filters * 4, layers[4], stride=1)
self.inplanes = num_filters
self.up_layer2 = self._make_layer(block, num_filters, layers[5], stride=1)
self.defilter = nn.Conv2d(num_filters, 3, kernel_size=5, stride=1, padding=2, bias=False)
for m in self.modules():
if isinstance(m, FixupBasicBlock):
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
nn.init.constant_(m.conv2.weight, 0)
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
elif isinstance(m, nn.Linear):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
skip = x
# Noise has the same shape as the input with only one channel.
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
x = torch.cat([x, rand_feature], dim=1)
x = self.conv1(x)
x = self.lrelu(x + self.bias1)
x = self.down_layer1(x)
x = self.down_layer2(x)
x = self.down_layer3(x)
x = self.down_layer4(x)
x = self.pixel_shuffle(x)
x = self.up_layer1(x)
x = self.pixel_shuffle(x)
x = self.up_layer2(x)
x = self.defilter(x)
base = F.interpolate(skip, scale_factor=.25, mode='bilinear', align_corners=False)
return x + base
def fixup_resnet34(num_filters, **kwargs):
"""Constructs a Fixup-ResNet-34 model.
"""
model = FixupResNet(FixupBasicBlock, num_filters, [3, 4, 6, 3, 2, 2], **kwargs)
return model
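A hedged usage sketch for the constructor above; the output is the network's learned correction added to the input bilinearly resized to quarter resolution:

import torch

model = fixup_resnet34(num_filters=8).eval()
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 16, 16])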

@@ -1,122 +0,0 @@
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
import torch
class ReduceAnnealer(nn.Module):
    '''
    Halves an image's spatial dimensions and runs a specified number of residual blocks on it before
    `annealing` the filter count back down to the input filter count.
    To reduce depth, accepts an interpolated "trunk" input which is summed with the output of the RA block
    before returning.
    Returns a tuple in the forward pass. The first element is the annealed output. The second is the output
    before annealing (with number_filters = input * 4), which can be used for upsampling.
    '''
def __init__(self, number_filters, residual_blocks):
super(ReduceAnnealer, self).__init__()
self.reducer = nn.Conv2d(number_filters, number_filters*4, 3, stride=2, padding=1, bias=True)
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
arch_util.initialize_weights([self.reducer, self.annealer], .1)
self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
def forward(self, x, interpolated_trunk):
out = self.lrelu(self.bn_reduce(self.reducer(x)))
out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
return annealed, out
class Assembler(nn.Module):
    '''
    Upsamples a given input spatially with PixelShuffle, restores the filter count with a convolution, and adds
    in a residual raw input from the corresponding upstream ReduceAnnealer. Finally performs processing using
    ResNet blocks.
    '''
def __init__(self, number_filters, residual_blocks):
super(Assembler, self).__init__()
self.pixel_shuffle = nn.PixelShuffle(2)
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
def forward(self, input, skip_raw):
out = self.pixel_shuffle(input)
out = self.bn_up(self.upsampler(out)) + skip_raw
out = self.lrelu(self.bn(self.res_trunk(out)))
return out
class FlatProcessorNet(nn.Module):
    '''
    Specialized network that tries to perform a near-equal amount of processing in each of 5 downsampling steps,
    then reassembles the image up to the requested downscale with a similarly flat amount of processing.
    This network automatically concatenates a noise channel onto its input to provide entropy for processing.
    '''
def __init__(self, in_nc=3, out_nc=3, nf=64, reduce_anneal_blocks=4, assembler_blocks=2, downscale=4):
super(FlatProcessorNet, self).__init__()
        # downscale=1 would require a fifth assembler plus a full-resolution skip with nf*4 channels,
        # neither of which this architecture produces, so only 2 and 4 are accepted.
        assert downscale in [2, 4], "Requested downscale not supported; %i" % (downscale, )
self.downscale = downscale
# We will always apply a noise channel to the inputs, account for that here.
in_nc += 1
# We need two layers to move the image into the filter space in which we will perform most of the work.
self.conv_first = nn.Conv2d(in_nc, nf, 3, stride=1, padding=1, bias=True)
self.conv_last = nn.Conv2d(nf*4, out_nc, 3, stride=1, padding=1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# Torch modules need to have all submodules as explicit class members. So make those, then add them into an
# array for easier logic in forward().
self.ra1 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra2 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra3 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra4 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.ra5 = ReduceAnnealer(nf, reduce_anneal_blocks)
self.reducers = [self.ra1, self.ra2, self.ra3, self.ra4, self.ra5]
# Produce assemblers for all possible downscale variants. Some may not be used.
self.assembler1 = Assembler(nf, assembler_blocks)
self.assembler2 = Assembler(nf, assembler_blocks)
self.assembler3 = Assembler(nf, assembler_blocks)
self.assembler4 = Assembler(nf, assembler_blocks)
self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
# Initialization
arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
def forward(self, x):
# Noise has the same shape as the input with only one channel.
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
out = torch.cat([x, rand_feature], dim=1)
out = self.lrelu(self.conv_first(out))
features_trunk = out
raw_values = []
downsamples = 1
for ra in self.reducers:
downsamples *= 2
interpolated = F.interpolate(features_trunk, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
out, raw = ra(out, interpolated)
raw_values.append(raw)
i = -1
out = raw_values[-1]
while downsamples != self.downscale:
out = self.assemblers[i](out, raw_values[i-1])
i -= 1
downsamples = int(downsamples / 2)
out = self.conv_last(out)
basis = x
if downsamples != 1:
basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
return basis + out
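A hedged usage sketch (assumes models.archs.arch_util supplies the ResidualBlock referenced above; input sides must be divisible by 32 because five stride-2 reductions are applied):

import torch

net = FlatProcessorNet(downscale=4).eval()
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 16, 16])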

@@ -1,86 +0,0 @@
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
import torch
class HighToLowResNet(nn.Module):
    ''' ResNet that concatenates a noise channel onto the input, then downsamples it three times with strided
    convolutions. Finally, the result is upsampled back to the desired downscale. Currently downscale=1, 2 and 4
    are supported.
    '''
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, downscale=4):
super(HighToLowResNet, self).__init__()
assert downscale in [1, 2, 4], "Requested downscale not supported; %i" % (downscale, )
self.downscale = downscale
# We will always apply a noise channel to the inputs, account for that here.
in_nc += 1
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
# All sub-modules must be explicit members. Make it so. Then add them to a list.
self.trunk1 = arch_util.make_layer(functools.partial(arch_util.ResidualBlock_noBN, nf=nf), 4)
self.trunk2 = arch_util.make_layer(functools.partial(arch_util.ResidualBlock_noBN, nf=nf*2), 6)
self.trunk3 = arch_util.make_layer(functools.partial(arch_util.ResidualBlock_noBN, nf=nf*4), 12)
self.trunk4 = arch_util.make_layer(functools.partial(arch_util.ResidualBlock_noBN, nf=nf*8), 12)
self.trunks = [self.trunk1, self.trunk2, self.trunk3, self.trunk4]
self.trunkshapes = [4, 6, 12, 12]
self.r1 = nn.Conv2d(nf, nf*2, 3, stride=2, padding=1, bias=True)
self.r2 = nn.Conv2d(nf*2, nf*4, 3, stride=2, padding=1, bias=True)
self.r3 = nn.Conv2d(nf*4, nf*8, 3, stride=2, padding=1, bias=True)
self.reducers = [self.r1, self.r2, self.r3]
self.pixel_shuffle = nn.PixelShuffle(2)
        # Each assembler runs right after a PixelShuffle (which quarters the filter count) and must emit the
        # same channel count as the skip connection it is summed with in forward().
        self.a1 = nn.Conv2d(nf*2, nf*4, 3, stride=1, padding=1, bias=True)
        self.a2 = nn.Conv2d(nf, nf*2, 3, stride=1, padding=1, bias=True)
        self.a3 = nn.Conv2d(nf//2, nf, 3, stride=1, padding=1, bias=True)
self.assemblers = [self.a1, self.a2, self.a3]
        if self.downscale == 1:
            nf_last = nf
        elif self.downscale == 2:
            nf_last = nf * 2
        elif self.downscale == 4:
            nf_last = nf * 4
self.conv_last = nn.Conv2d(nf_last, out_nc, 3, stride=1, padding=1, bias=True)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# initialization
arch_util.initialize_weights([self.conv_first, self.conv_last] + self.reducers + self.assemblers,
.1)
def forward(self, x):
# Noise has the same shape as the input with only one channel.
rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype)
out = torch.cat([x, rand_feature], dim=1)
out = self.lrelu(self.conv_first(out))
skips = []
for i in range(4):
skips.append(out)
out = self.trunks[i](out)
if i < 3:
out = self.lrelu(self.reducers[i](out))
        # Integer division so the width comparison below stays int == int.
        target_width = x.shape[-1] // self.downscale
i = 0
while out.shape[-1] != target_width:
out = self.pixel_shuffle(out)
out = self.lrelu(self.assemblers[i](out))
out = out + skips[-i-2]
i += 1
# TODO: Figure out where this magic number '12' comes from and fix it.
out = 12 * self.conv_last(out)
if self.downscale == 1:
base = x
else:
base = F.interpolate(x, scale_factor=1/self.downscale, mode='bilinear', align_corners=False)
return out + base
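A hedged usage sketch (assumes models.archs.arch_util supplies ResidualBlock_noBN and initialize_weights):

import torch

net = HighToLowResNet(downscale=4).eval()
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 16, 16])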

@@ -1,226 +0,0 @@
import torch
from torch import nn
from models.archs.arch_util import ConvBnLelu, ConvBnRelu, MultiConvBlock
from switched_conv import BareConvSwitch, compute_attention_specificity
from switched_conv_util import save_attention_to_image
from functools import partial
import torch.nn.functional as F
from collections import OrderedDict
class Switch(nn.Module):
def __init__(self, transform_block, transform_count, init_temp=20, pass_chain_forward=False, add_scalable_noise_to_transforms=False):
super(Switch, self).__init__()
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
self.add_noise = add_scalable_noise_to_transforms
self.pass_chain_forward = pass_chain_forward
# And the switch itself, including learned scalars
self.switch = BareConvSwitch(initial_temperature=init_temp)
self.scale = nn.Parameter(torch.ones(1))
self.bias = nn.Parameter(torch.zeros(1))
# x is the input fed to the transform blocks.
# m is the output of the multiplexer which will be used to select from those transform blocks.
# chain is a chain of shared processing outputs used by the individual transforms.
def forward(self, x, m, chain):
if self.pass_chain_forward:
pcf = [t(x, chain) for t in self.transforms]
xformed = [o[0] for o in pcf]
atts = [o[1] for o in pcf]
else:
if self.add_noise:
rand_feature = torch.randn_like(x)
xformed = [t(x, rand_feature) for t in self.transforms]
else:
xformed = [t(x) for t in self.transforms]
# Interpolate the multiplexer across the entire shape of the image.
m = F.interpolate(m, size=x.shape[2:], mode='nearest')
outputs, attention = self.switch(xformed, m, True)
outputs = outputs * self.scale + self.bias
if self.pass_chain_forward:
# Apply attention weights to collected [atts] and return the aggregate.
atts = torch.stack(atts, dim=3)
attention = atts * attention.unsqueeze(dim=-1)
attention = torch.flatten(attention, 3)
return outputs, attention
def set_temperature(self, temp):
self.switch.set_attention_temperature(temp)
if self.pass_chain_forward:
[t.set_temperature(temp) for t in self.transforms]
# Convolutional image processing block that optionally reduces image size by a factor of 2 using stride and performs a
# series of conv blocks on it
class Processor(nn.Module):
def __init__(self, base_filters, processing_depth, reduce=False):
super(Processor, self).__init__()
self.output_filter_count = base_filters * (2 if reduce else 1)
self.pre = ConvBnRelu(base_filters, base_filters, kernel_size=3, bias=True)
self.initial = ConvBnRelu(base_filters, self.output_filter_count, kernel_size=1, stride=2 if reduce else 1, bias=False)
self.blocks = nn.Sequential(OrderedDict(
[(str(i), ConvBnRelu(self.output_filter_count, self.output_filter_count, kernel_size=3, bias=False)) for i in range(processing_depth)]))
def forward(self, x):
x = self.pre(x)
x = self.initial(x)
x = self.blocks(x)
return (x - .39) / .58
# Convolutional image processing block that constricts an input image with a large number of filters to a small number
# of filters over a fixed number of layers.
class Constrictor(nn.Module):
def __init__(self, filters, output_filters):
super(Constrictor, self).__init__()
assert(filters > output_filters)
gap = filters - output_filters
gap_div_4 = int(gap / 4)
self.cbl1 = ConvBnRelu(filters, filters - (gap_div_4 * 2), kernel_size=1, norm=True, bias=True)
self.cbl2 = ConvBnRelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, norm=True, bias=False)
self.cbl3 = ConvBnRelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, activation=False, norm=False, bias=False)
def forward(self, x):
x = self.cbl1(x)
x = self.cbl2(x)
x = self.cbl3(x)
return x / 2.67
class RecursiveSwitchedTransform(nn.Module):
def __init__(self, transform_filters, filters_count_list, nesting_depth, transforms_at_leaf,
trans_kernel_size, trans_num_layers, trans_scale_init=1, initial_temp=20, add_scalable_noise_to_transforms=False):
super(RecursiveSwitchedTransform, self).__init__()
self.depth = nesting_depth
at_leaf = (self.depth == 0)
if at_leaf:
transform = partial(MultiConvBlock, transform_filters, transform_filters, transform_filters, kernel_size=trans_kernel_size, depth=trans_num_layers, scale_init=trans_scale_init)
else:
transform = partial(RecursiveSwitchedTransform, transform_filters, filters_count_list,
nesting_depth - 1, transforms_at_leaf, trans_kernel_size, trans_num_layers, trans_scale_init, initial_temp, add_scalable_noise_to_transforms)
selection_breadth = transforms_at_leaf if at_leaf else 2
self.switch = Switch(transform, selection_breadth, initial_temp, pass_chain_forward=not at_leaf, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
self.multiplexer = Constrictor(filters_count_list[self.depth], selection_breadth)
def forward(self, x, processing_trunk_chain):
proc_out = processing_trunk_chain[self.depth]
m = self.multiplexer(proc_out)
return self.switch(x, m, processing_trunk_chain)
def set_temperature(self, temp):
self.switch.set_temperature(temp)
class NestedSwitchComputer(nn.Module):
def __init__(self, transform_filters, switch_base_filters, num_switch_processing_layers, nesting_depth, transforms_at_leaf,
trans_kernel_size, trans_num_layers, trans_scale_init, initial_temp=20, add_scalable_noise_to_transforms=False):
super(NestedSwitchComputer, self).__init__()
processing_trunk = []
filters = []
current_filters = switch_base_filters
reduce = False # Don't reduce the first layer, but reduce after that.
for _ in range(nesting_depth):
processing_trunk.append(Processor(current_filters, num_switch_processing_layers, reduce=reduce))
current_filters = processing_trunk[-1].output_filter_count
filters.append(current_filters)
reduce = True
self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, activation=False, norm=False)
self.processing_trunk = nn.ModuleList(processing_trunk)
self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, norm=False)
def forward(self, x):
feed_forward = x
trunk = []
trunk_input = self.multiplexer_init_conv(x)
for m in self.processing_trunk:
trunk_input = (m(trunk_input) - 3.3) / 12.5
trunk.append(trunk_input)
self.trunk = trunk[-1]
x, att = self.switch(x, trunk)
x = x + feed_forward
return feed_forward + self.anneal(x) / .86, att
def set_temperature(self, temp):
self.switch.set_temperature(temp)
class NestedSwitchedGenerator(nn.Module):
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
super(NestedSwitchedGenerator, self).__init__()
self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, activation=False, norm=False)
self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False)
self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, activation=False, norm=False)
switches = []
for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
switches.append(NestedSwitchComputer(transform_filters=transformation_filters, switch_base_filters=switch_filters, num_switch_processing_layers=sw_proc,
nesting_depth=sw_reduce, transforms_at_leaf=trans_count, trans_kernel_size=kernel, trans_num_layers=layers,
trans_scale_init=.2, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
self.switches = nn.ModuleList(switches)
self.transformation_counts = trans_counts
self.init_temperature = initial_temp
self.final_temperature_step = final_temperature_step
self.heightened_temp_min = heightened_temp_min
self.heightened_final_step = heightened_final_step
self.attentions = None
self.upsample_factor = upsample_factor
def forward(self, x):
x = self.initial_conv(x) / .2
self.attentions = []
for i, sw in enumerate(self.switches):
x, att = sw(x)
self.attentions.append(att)
if self.upsample_factor > 1:
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
x = self.proc_conv(x) / .85
x = self.final_conv(x) / 4.6
        return x / 16,  # note the trailing comma: returns a 1-tuple
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
                # Once the temperature reaches 1 it enters an inverted curve to match the linear curve from
                # above. Without this, the attention specificity "spikes" incredibly fast in the last few
                # iterations.
                h_steps_total = self.heightened_final_step - self.final_temperature_step
                # Clamp to [1, h_steps_total]: integer truncation can drive temp to 1 slightly before
                # final_temperature_step, which would otherwise make this zero or negative.
                h_steps_current = min(max(step - self.final_temperature_step, 1), h_steps_total)
# The "gap" will represent the steps that need to be traveled as a linear function.
h_gap = 1 / self.heightened_temp_min
temp = h_gap * h_steps_current / h_steps_total
# Invert temperature to represent reality on this side of the curve
temp = 1 / temp
self.set_temperature(temp)
if step % 50 == 0:
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))]
def get_debug_values(self, step):
        # Drill down: NestedSwitchComputer -> RecursiveSwitchedTransform -> Switch -> BareConvSwitch.
        temp = self.switches[0].switch.switch.switch.temperature
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {"switch_temperature": temp}
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
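The two-phase temperature schedule in update_for_step is easier to see in isolation. A standalone sketch with hypothetical settings (a heightened_temp_min below 1 engages the inverted phase):

def nested_switch_temperature(step, init_temp=20, final_temperature_step=50000,
                              heightened_temp_min=0.5, heightened_final_step=100000):
    # Phase 1: linear decay from init_temp down to 1.
    temp = max(1, int(init_temp * (final_temperature_step - step) / final_temperature_step))
    if temp == 1 and heightened_final_step and heightened_final_step != 1:
        # Phase 2: inverted-linear descent from 1 toward heightened_temp_min.
        h_steps_total = heightened_final_step - final_temperature_step
        h_steps_current = min(max(step - final_temperature_step, 1), h_steps_total)
        temp = 1 / ((1 / heightened_temp_min) * h_steps_current / h_steps_total)
    return temp

# With the settings above: step 0 -> 20, step 25000 -> 10,
# step 75000 -> 1.0, step 100000 -> 0.5.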