forked from mrq/DL-Art-School
407224eba1
Got rid of the converged multiplexer bases but kept the configurable architecture. The new multiplexers look a lot like the old one. Took some cues from the transformer architecture: translate the image to a higher filter-space and stay there for the duration of the model's computation. Also perform convs after each switch to allow the model to anneal issues that arise.
376 lines
19 KiB
Python
376 lines
19 KiB
Python
import torch
|
|
from torch import nn
|
|
from switched_conv import BareConvSwitch, compute_attention_specificity
|
|
import torch.nn.functional as F
|
|
import functools
|
|
from collections import OrderedDict
|
|
from models.archs.arch_util import initialize_weights
|
|
from switched_conv_util import save_attention_to_image
|
|
|
|
|
|
class ConvBnLelu(nn.Module):
    """A 2D convolution optionally followed by batch norm and a leaky ReLU.

    Args:
        filters_in: Number of input channels given to the convolution.
        filters_out: Number of output channels produced by the convolution.
        kernel_size: Convolution kernel size; must be one of 1, 3, 5 or 7.
        stride: Convolution stride.
        lelu: When true, ``LeakyReLU(negative_slope=.1)`` is applied last.
        bn: When true, ``BatchNorm2d`` is applied right after the convolution.
    """

    # Padding used for each supported kernel size, i.e. (kernel_size - 1) / 2.
    _PADDING_BY_KERNEL = {1: 0, 3: 1, 5: 2, 7: 3}

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True):
        super(ConvBnLelu, self).__init__()
        assert kernel_size in self._PADDING_BY_KERNEL
        pad = self._PADDING_BY_KERNEL[kernel_size]
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, pad)
        # Disabled stages are stored as None so forward() can skip them.
        self.bn = nn.BatchNorm2d(filters_out) if bn else None
        self.lelu = nn.LeakyReLU(negative_slope=.1) if lelu else None

    def forward(self, x):
        """Run conv -> (batch norm) -> (leaky ReLU) and return the result."""
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.lelu is None:
            return out
        return self.lelu(out)
|
|
|
|
|
|
class ResidualBranch(nn.Module):
    """A stack of ``depth`` batch-norm-free convolutions with a learned affine output.

    The stack maps ``filters_in -> filters_mid -> ... -> filters_out``; every
    layer except the last is followed by a leaky ReLU. The stack output is
    multiplied by the learned scalar ``scale`` and offset by the learned
    scalar ``bias``.

    Args:
        filters_in: Channel count of the input.
        filters_mid: Channel count used between the first and last layers.
        filters_out: Channel count of the output.
        kernel_size: Kernel size used by every convolution.
        depth: Total number of convolutions; must be at least 2.
    """

    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth):
        assert depth >= 2
        super(ResidualBranch, self).__init__()
        # Learned multiplier applied to the optional noise input; starts at .01.
        self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=False)] +
                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=False) for _ in range(depth - 2)] +
                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)])
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        """Apply the convolution stack to ``x``.

        Args:
            x: Input tensor.
            noise: Optional tensor added to ``x`` after being multiplied by
                ``noise_scale``. Skipped when None.

        Returns:
            ``stack(x) * scale + bias``.
        """
        if noise is not None:
            x = x + noise * self.noise_scale
        for layer in self.bnconvs:
            # Call the module rather than .forward() so that registered
            # forward hooks are run.
            x = layer(x)
        return x * self.scale + self.bias
|
|
|
|
|
|
# VGG-style layer: Conv(stride 2)->Activation->Conv->BN->Activation.
# Doubles the input filter count.
class HalvingProcessingBlock(nn.Module):
    """Strided convolution pair that doubles the channel count.

    The first convolution uses stride 2 and no batch norm; the second keeps
    the doubled channel count at stride 1 and does use batch norm.

    Args:
        filters: Channel count of the input; the output has ``filters * 2``.
    """

    def __init__(self, filters):
        super(HalvingProcessingBlock, self).__init__()
        doubled = filters * 2
        self.bnconv1 = ConvBnLelu(filters, doubled, stride=2, bn=False)
        self.bnconv2 = ConvBnLelu(doubled, doubled, bn=True)

    def forward(self, x):
        """Return ``bnconv2(bnconv1(x))``."""
        return self.bnconv2(self.bnconv1(x))
|
|
|
|
|
|
# Creates a chain of convolutional blocks. Each block takes the output of the previous one and adds
# filter_growth filters. Return is (nn.Sequential, ending_filters)
def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
    """Build ``num_convs`` ConvBnLelu layers whose width grows by ``filter_growth`` each.

    Args:
        filters_init: Channel count fed to the first layer.
        filter_growth: Channels added by every layer.
        num_convs: Number of layers to create.

    Returns:
        A tuple ``(nn.Sequential of the layers, channel count after the last
        layer)``. With no layers the channel count is ``filters_init``.
    """
    input_widths = [filters_init + filter_growth * n for n in range(num_convs)]
    layers = [ConvBnLelu(width, width + filter_growth, bn=True) for width in input_widths]
    final_width = input_widths[-1] + filter_growth if input_widths else filters_init
    return nn.Sequential(*layers), final_width
|
|
|
|
|
|
class SwitchComputer(nn.Module):
    """Applies several transforms to an input and blends them with a learned switch.

    A small convolutional "multiplexer" network is run on the same input to
    produce one channel per transform; it is resized to the input's spatial
    size and handed to a ``BareConvSwitch`` together with the transform
    outputs. The switch output is finally multiplied by the learned scalar
    ``scale`` and offset by the learned scalar ``bias``.

    Args:
        channels_in: Channel count of the input.
        filters: Channel count produced by the multiplexer's first convolution.
        growth: Channels added by each multiplexer processing block.
        transform_block: Zero-argument callable returning a new transform module.
        transform_count: Number of transform modules to create.
        reduction_blocks: Number of HalvingProcessingBlocks in the multiplexer.
        processing_blocks: Number of growing processing blocks in the multiplexer.
        init_temp: Passed to ``BareConvSwitch`` as ``initial_temperature``.
        enable_negative_transforms: When true, the negation of every transform
            output is offered to the switch as well, doubling the number of
            multiplexer channels.
        add_scalable_noise_to_transforms: When true, a ``randn_like`` tensor is
            passed to every transform as its second argument.
    """

    def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0,
                 init_temp=20, enable_negative_transforms=False, add_scalable_noise_to_transforms=False):
        super(SwitchComputer, self).__init__()
        self.enable_negative_transforms = enable_negative_transforms

        self.filter_conv = ConvBnLelu(channels_in, filters)
        self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
        final_filters = filters * 2 ** reduction_blocks
        self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
        proc_block_filters = max(final_filters // 2, transform_count)
        self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
        # One multiplexer channel per switch input; negated transforms count too.
        tc = transform_count * 2 if self.enable_negative_transforms else transform_count
        self.final_switch_conv = nn.Conv2d(proc_block_filters, tc, 1, 1, 0)

        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
        self.add_noise = add_scalable_noise_to_transforms

        # And the switch itself, including learned scalars
        self.switch = BareConvSwitch(initial_temperature=init_temp)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, output_attention_weights=False):
        """Blend the transform outputs for ``x``.

        Sub-modules are invoked by calling them (not via ``.forward``) so that
        registered forward hooks are run.

        Args:
            x: Input tensor; its dimensions from index 2 onward are used as the
                target size for the multiplexer output.
            output_attention_weights: When true, also return the attention
                produced by the switch.

        Returns:
            ``outputs`` or, if requested, ``(outputs, attention)``.
        """
        if self.add_noise:
            rand_feature = torch.randn_like(x)
            xformed = [t(x, rand_feature) for t in self.transforms]
        else:
            xformed = [t(x) for t in self.transforms]
        if self.enable_negative_transforms:
            xformed.extend([-t for t in xformed])

        multiplexer = self.filter_conv(x)
        for block in self.reduction_blocks:
            multiplexer = block(multiplexer)
        multiplexer = self.processing_blocks(multiplexer)
        multiplexer = self.proc_switch_conv(multiplexer)
        multiplexer = self.final_switch_conv(multiplexer)
        # Interpolate the multiplexer across the entire shape of the image.
        multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest')

        # NOTE(review): the trailing True presumably asks the switch to return
        # the attention as well - confirm against BareConvSwitch.
        outputs, attention = self.switch(xformed, multiplexer, True)
        outputs = outputs * self.scale + self.bias
        if output_attention_weights:
            return outputs, attention
        return outputs

    def set_temperature(self, temp):
        """Forward ``temp`` to the switch's ``set_attention_temperature``."""
        self.switch.set_attention_temperature(temp)
|
|
|
|
|
|
class ConfigurableSwitchComputer(nn.Module):
    """Switch computer whose multiplexer network is supplied by the caller.

    Several transforms are applied to the input; a multiplexer network run on
    the same input is resized to the input's spatial size and handed to a
    ``BareConvSwitch`` together with the transform outputs. The switch output
    is multiplied by the learned scalar ``scale`` and offset by ``bias``.

    Args:
        multiplexer_net: Callable invoked as ``multiplexer_net(tc)`` to build
            the multiplexer module, where ``tc`` is the number of switch inputs
            (``transform_count``, doubled when negative transforms are enabled).
        transform_block: Zero-argument callable returning a new transform module.
        transform_count: Number of transform modules to create.
        init_temp: Passed to ``BareConvSwitch`` as ``initial_temperature``.
        enable_negative_transforms: When true, the negation of every transform
            output is offered to the switch as well.
        add_scalable_noise_to_transforms: When true, a ``randn_like`` tensor is
            passed to every transform as its second argument.
    """

    def __init__(self, multiplexer_net, transform_block, transform_count, init_temp=20,
                 enable_negative_transforms=False, add_scalable_noise_to_transforms=False):
        super(ConfigurableSwitchComputer, self).__init__()
        self.enable_negative_transforms = enable_negative_transforms

        tc = transform_count * 2 if self.enable_negative_transforms else transform_count
        self.multiplexer = multiplexer_net(tc)

        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
        self.add_noise = add_scalable_noise_to_transforms

        # And the switch itself, including learned scalars
        self.switch = BareConvSwitch(initial_temperature=init_temp)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, output_attention_weights=False):
        """Blend the transform outputs for ``x``.

        Sub-modules are invoked by calling them (not via ``.forward``) so that
        registered forward hooks are run.

        Args:
            x: Input tensor; its dimensions from index 2 onward are used as the
                target size for the multiplexer output.
            output_attention_weights: When true, also return the attention
                produced by the switch.

        Returns:
            ``outputs`` or, if requested, ``(outputs, attention)``.
        """
        if self.add_noise:
            rand_feature = torch.randn_like(x)
            xformed = [t(x, rand_feature) for t in self.transforms]
        else:
            xformed = [t(x) for t in self.transforms]
        if self.enable_negative_transforms:
            xformed.extend([-t for t in xformed])

        m = self.multiplexer(x)
        # Interpolate the multiplexer across the entire shape of the image.
        m = F.interpolate(m, size=x.shape[2:], mode='nearest')

        # NOTE(review): the trailing True presumably asks the switch to return
        # the attention as well - confirm against BareConvSwitch.
        outputs, attention = self.switch(xformed, m, True)
        outputs = outputs * self.scale + self.bias
        if output_attention_weights:
            return outputs, attention
        return outputs

    def set_temperature(self, temp):
        """Forward ``temp`` to the switch's ``set_attention_temperature``."""
        self.switch.set_attention_temperature(temp)
|
|
|
|
|
|
class ConvBasisMultiplexer(nn.Module):
    """Convolutional multiplexer: reduce, process, then narrow to the switch channels.

    The input goes through an initial convolution, ``reductions`` halving
    blocks, ``processing_depth`` growing blocks and finally three convolutions
    that step the channel count from the processed width towards
    ``multiplexer_channels`` (removing a quarter of the gap, then half of it,
    then the rest).

    Args:
        input_channels: Channel count of the input.
        base_filters: Channel count produced by the first convolution.
        growth: Channels added by each processing block.
        reductions: Number of HalvingProcessingBlocks.
        processing_depth: Number of growing processing blocks.
        multiplexer_channels: Channel count of the output.
        use_bn: Whether the first two narrowing convolutions use batch norm.
            The last one is built with ConvBnLelu's defaults regardless.
    """

    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
        super(ConvBasisMultiplexer, self).__init__()
        self.filter_conv = ConvBnLelu(input_channels, base_filters)
        halving = OrderedDict()
        for i in range(reductions):
            halving['block%i:' % (i,)] = HalvingProcessingBlock(base_filters * 2 ** i)
        self.reduction_blocks = nn.Sequential(halving)
        reduction_filters = base_filters * 2 ** reductions
        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(
            reduction_filters, growth, processing_depth)

        processed = self.output_filter_count
        gap = processed - multiplexer_channels
        after_quarter = processed - (gap // 4)
        after_half = processed - (gap // 2)
        self.cbl1 = ConvBnLelu(processed, after_quarter, bn=use_bn)
        self.cbl2 = ConvBnLelu(after_quarter, after_half, bn=use_bn)
        self.cbl3 = ConvBnLelu(after_half, multiplexer_channels)

    def forward(self, x):
        """Run every stage in construction order and return the final output."""
        stages = (self.filter_conv, self.reduction_blocks, self.processing_blocks,
                  self.cbl1, self.cbl2, self.cbl3)
        for stage in stages:
            x = stage(x)
        return x
|
|
|
|
|
|
class ConvBasisMultiplexerBase(nn.Module):
    """The reduce-and-process front half of a convolutional multiplexer.

    Runs an initial convolution, ``reductions`` halving blocks and
    ``processing_depth`` growing blocks. The resulting channel count is
    exposed as ``output_filter_count``.

    Args:
        input_channels: Channel count of the input.
        base_filters: Channel count produced by the first convolution.
        growth: Channels added by each processing block.
        reductions: Number of HalvingProcessingBlocks.
        processing_depth: Number of growing processing blocks.
    """

    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth):
        super(ConvBasisMultiplexerBase, self).__init__()
        self.filter_conv = ConvBnLelu(input_channels, base_filters)
        halving = OrderedDict()
        for i in range(reductions):
            halving['block%i:' % (i,)] = HalvingProcessingBlock(base_filters * 2 ** i)
        self.reduction_blocks = nn.Sequential(halving)
        reduction_filters = base_filters * 2 ** reductions
        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(
            reduction_filters, growth, processing_depth)

    def forward(self, x):
        """Return ``processing_blocks(reduction_blocks(filter_conv(x)))``."""
        return self.processing_blocks(self.reduction_blocks(self.filter_conv(x)))
|
|
|
|
|
|
class ConvBasisMultiplexerLeaf(nn.Module):
    """Narrowing head placed on top of a caller-supplied base module.

    The base output, expected to have ``filters`` channels, is passed through
    three convolutions that step the channel count down to
    ``multiplexer_channels`` (removing a quarter of the gap, then half of it,
    then the rest).

    Args:
        base: Module applied to the input first.
        filters: Channel count produced by ``base``; must exceed
            ``multiplexer_channels`` by a multiple of 4.
        multiplexer_channels: Channel count of the output.
        use_bn: Whether the first two convolutions use batch norm. The last
            one is built with ConvBnLelu's defaults regardless.
    """

    def __init__(self, base, filters, multiplexer_channels, use_bn=False):
        super(ConvBasisMultiplexerLeaf, self).__init__()
        assert filters > multiplexer_channels
        gap = filters - multiplexer_channels
        assert gap % 4 == 0
        after_quarter = filters - (gap // 4)
        after_half = filters - (gap // 2)
        self.base = base
        self.cbl1 = ConvBnLelu(filters, after_quarter, bn=use_bn)
        self.cbl2 = ConvBnLelu(after_quarter, after_half, bn=use_bn)
        self.cbl3 = ConvBnLelu(after_half, multiplexer_channels)

    def forward(self, x):
        """Run the base and then the three narrowing convolutions in order."""
        for stage in (self.base, self.cbl1, self.cbl2, self.cbl3):
            x = stage(x)
        return x
|
|
|
|
|
|
class ConfigurableSwitchedResidualGenerator(nn.Module):
    """Chain of SwitchComputers whose outputs are added residually to a 3-channel image.

    One switch is created per position of the list arguments (zipped, so the
    shortest list decides the count). Every switch works on 3 channels and
    uses ResidualBranch transforms.

    Args:
        switch_filters, switch_growths, switch_reductions,
        switch_processing_layers: Per-switch multiplexer settings passed to
            SwitchComputer.
        trans_counts: Per-switch number of transforms.
        trans_kernel_sizes, trans_layers, trans_filters_mid: Per-switch
            ResidualBranch kernel size, depth and middle channel count.
        initial_temp: Initial switch temperature; also the start of the
            schedule applied by ``update_for_step``.
        final_temperature_step: Step at which the linear schedule ends.
        heightened_temp_min: Temperature reached at ``heightened_final_step``.
        heightened_final_step: Last step of the second ("heightened") phase.
            A falsy value or 1 disables that phase.
        upsample_factor: When greater than 1 the input is nearest-neighbour
            upsampled by this factor before any switch is applied.
        enable_negative_transforms, add_scalable_noise_to_transforms:
            Forwarded to every SwitchComputer.
    """

    def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                 trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                 heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
                 add_scalable_noise_to_transforms=False):
        super(ConfigurableSwitchedResidualGenerator, self).__init__()
        switches = []
        for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(
                switch_filters, switch_growths, switch_reductions, switch_processing_layers,
                trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
            transform_fn = functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers)
            switches.append(SwitchComputer(3, filters, growth, transform_fn, trans_count, sw_reduce, sw_proc, initial_temp,
                                           enable_negative_transforms=enable_negative_transforms,
                                           add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
        initialize_weights(switches, 1)
        # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
        initialize_weights([s.transforms for s in switches], .2 / len(switches))
        self.switches = nn.ModuleList(switches)
        self.transformation_counts = trans_counts
        self.init_temperature = initial_temp
        self.final_temperature_step = final_temperature_step
        self.heightened_temp_min = heightened_temp_min
        self.heightened_final_step = heightened_final_step
        # Attention from the most recent forward(); None until forward() has run.
        self.attentions = None
        self.upsample_factor = upsample_factor

    def forward(self, x):
        """Apply every switch residually and return the result as a 1-tuple ``(x,)``.

        The attention produced by each switch is kept in ``self.attentions``.
        """
        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
        # is called for, then repair.
        if self.upsample_factor > 1:
            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")

        self.attentions = []
        for sw in self.switches:
            # Call the module (not .forward) so registered hooks are run.
            sw_out, att = sw(x, True)
            x = x + sw_out
            self.attentions.append(att)
        return x,

    def set_temperature(self, temp):
        """Set ``temp`` on every switch."""
        for sw in self.switches:
            sw.set_temperature(temp)

    def _temperature_for_step(self, step):
        """Return the switch temperature scheduled for training step ``step``."""
        temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
        if temp != 1 or not self.heightened_final_step or self.heightened_final_step == 1:
            return temp
        h_steps_total = self.heightened_final_step - self.final_temperature_step
        # int() truncation makes the linear phase hit 1 before
        # final_temperature_step, and the default arguments give a zero-length
        # heightened phase. Without this guard those cases produced negative
        # temperatures or a ZeroDivisionError.
        if step <= self.final_temperature_step or h_steps_total <= 0:
            return temp
        # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
        # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
        h_steps_current = min(step - self.final_temperature_step, h_steps_total)
        # The "gap" will represent the steps that need to be traveled as a linear function.
        h_gap = 1 / self.heightened_temp_min
        # NOTE(review): early in this phase the formula yields values above 1
        # (it only reaches heightened_temp_min at h_steps_total) - confirm that
        # this jump is intended.
        inverse_temp = h_gap * h_steps_current / h_steps_total
        # Invert temperature to represent reality on this side of the curve
        return 1 / inverse_temp

    def update_for_step(self, step, experiments_path='.'):
        """Update the switch temperature for ``step`` and periodically dump attention images.

        Does nothing until ``forward`` has stored attentions. Every 50th step
        the stored attentions are passed to ``save_attention_to_image``.
        """
        if not self.attentions:
            return
        self.set_temperature(self._temperature_for_step(step))
        if step % 50 == 0:
            for i, att in enumerate(self.attentions):
                save_attention_to_image(experiments_path, att, self.transformation_counts[i], step, "a%i" % (i + 1,))

    def get_debug_values(self, step):
        """Return a dict with the switch temperature and per-switch attention statistics.

        ``step`` is accepted for interface compatibility and is not used. If
        ``forward`` has not run yet only the temperature is reported.
        """
        val = {"switch_temperature": self.switches[0].switch.temperature}
        for i, att in enumerate(self.attentions or []):
            mean_hist = compute_attention_specificity(att, 2)
            val["switch_%i_specificity" % (i,)] = mean_hist[0]
            val["switch_%i_histogram" % (i,)] = mean_hist[1].clone().detach().cpu().flatten()
        return val
|
|
|
|
|
|
class ConfigurableSwitchedResidualGenerator2(nn.Module):
    """Switched residual generator that works in a wider filter space.

    The 3-channel input is lifted to ``transformation_filters`` channels by an
    initial convolution, processed by a chain of ConfigurableSwitchComputers
    (each followed by a residual convolution) and mapped back to 3 channels by
    a final convolution. One switch is created per position of the list
    arguments (zipped, so the shortest list decides the count).

    Args:
        switch_filters, switch_growths, switch_reductions,
        switch_processing_layers: Per-switch ConvBasisMultiplexer settings.
        trans_counts: Per-switch number of transforms.
        trans_kernel_sizes, trans_layers: Per-switch ResidualBranch kernel
            size and depth.
        transformation_filters: Channel count used between the initial and
            final convolutions.
        initial_temp: Initial switch temperature; also the start of the
            schedule applied by ``update_for_step``.
        final_temperature_step: Step at which the linear schedule ends.
        heightened_temp_min: Temperature reached at ``heightened_final_step``.
        heightened_final_step: Last step of the second ("heightened") phase.
            A falsy value or 1 disables that phase.
        upsample_factor: When greater than 1 the input is nearest-neighbour
            upsampled by this factor first.
        enable_negative_transforms, add_scalable_noise_to_transforms:
            Forwarded to every ConfigurableSwitchComputer.
    """

    def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                 heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
                 add_scalable_noise_to_transforms=False):
        super(ConfigurableSwitchedResidualGenerator2, self).__init__()
        switches = []
        post_switch_proc = []
        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False)
        for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(
                switch_filters, switch_growths, switch_reductions, switch_processing_layers,
                trans_counts, trans_kernel_sizes, trans_layers):
            # Leave multiplexer_channels unbound: ConfigurableSwitchComputer
            # calls this with the real number of switch inputs, which is twice
            # trans_count when negative transforms are enabled. Pre-binding
            # trans_count here made that argument land on use_bn instead.
            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc)
            transform_fn = functools.partial(ResidualBranch, transformation_filters, transformation_filters,
                                             transformation_filters, kernel_size=kernel, depth=layers)
            switches.append(ConfigurableSwitchComputer(multiplx_fn, transform_fn, trans_count, initial_temp,
                                                       enable_negative_transforms=enable_negative_transforms,
                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
            post_switch_proc.append(ConvBnLelu(transformation_filters, transformation_filters, bn=False))
        initialize_weights(switches, 1)
        # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
        initialize_weights([s.transforms for s in switches], .2 / len(switches))
        self.switches = nn.ModuleList(switches)
        initialize_weights(list(post_switch_proc), .01)
        self.post_switch_convs = nn.ModuleList(post_switch_proc)
        self.transformation_counts = trans_counts
        self.init_temperature = initial_temp
        self.final_temperature_step = final_temperature_step
        self.heightened_temp_min = heightened_temp_min
        self.heightened_final_step = heightened_final_step
        # Attention from the most recent forward(); None until forward() has run.
        self.attentions = None
        self.upsample_factor = upsample_factor

    def forward(self, x):
        """Run the generator and return the result as a 1-tuple ``(x,)``.

        The attention produced by each switch is kept in ``self.attentions``.
        """
        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
        # is called for, then repair.
        if self.upsample_factor > 1:
            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")

        x = self.initial_conv(x)

        self.attentions = []
        for sw, conv in zip(self.switches, self.post_switch_convs):
            # Call the module (not .forward) so registered hooks are run.
            sw_out, att = sw(x, True)
            self.attentions.append(att)
            x = x + sw_out
            # Residual convolution after each switch.
            x = x + conv(x)

        x = self.final_conv(x)
        return x,

    def set_temperature(self, temp):
        """Set ``temp`` on every switch."""
        for sw in self.switches:
            sw.set_temperature(temp)

    def _temperature_for_step(self, step):
        """Return the switch temperature scheduled for training step ``step``."""
        temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
        if temp != 1 or not self.heightened_final_step or self.heightened_final_step == 1:
            return temp
        h_steps_total = self.heightened_final_step - self.final_temperature_step
        # int() truncation makes the linear phase hit 1 before
        # final_temperature_step, and the default arguments give a zero-length
        # heightened phase. Without this guard those cases produced negative
        # temperatures or a ZeroDivisionError.
        if step <= self.final_temperature_step or h_steps_total <= 0:
            return temp
        # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
        # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
        h_steps_current = min(step - self.final_temperature_step, h_steps_total)
        # The "gap" will represent the steps that need to be traveled as a linear function.
        h_gap = 1 / self.heightened_temp_min
        # NOTE(review): early in this phase the formula yields values above 1
        # (it only reaches heightened_temp_min at h_steps_total) - confirm that
        # this jump is intended.
        inverse_temp = h_gap * h_steps_current / h_steps_total
        # Invert temperature to represent reality on this side of the curve
        return 1 / inverse_temp

    def update_for_step(self, step, experiments_path='.'):
        """Update the switch temperature for ``step`` and periodically dump attention images.

        Does nothing until ``forward`` has stored attentions. Every 50th step
        the stored attentions are passed to ``save_attention_to_image``.
        """
        if not self.attentions:
            return
        self.set_temperature(self._temperature_for_step(step))
        if step % 50 == 0:
            for i, att in enumerate(self.attentions):
                save_attention_to_image(experiments_path, att, self.transformation_counts[i], step, "a%i" % (i + 1,))

    def get_debug_values(self, step):
        """Return a dict with the switch temperature and per-switch attention statistics.

        ``step`` is accepted for interface compatibility and is not used. If
        ``forward`` has not run yet only the temperature is reported.
        """
        val = {"switch_temperature": self.switches[0].switch.temperature}
        for i, att in enumerate(self.attentions or []):
            mean_hist = compute_attention_specificity(att, 2)
            val["switch_%i_specificity" % (i,)] = mean_hist[0]
            val["switch_%i_histogram" % (i,)] = mean_hist[1].clone().detach().cpu().flatten()
        return val