DL-Art-School/codes/models/archs/SwitchedResidualGenerator_arch.py
James Betker 5f2c722a10 SRG2 revival
Big update to SRG2 architecture to pull in a lot of things that have been learned:
- Use group norm instead of batch norm
- Initialize the transformation weights low, as is done in RRDB, rather than relying on the scalar. Models live or die by their early stages, and this one's early stage is pretty weak.
- Transform the multiplexer to use a u-net-like architecture.
- Just use one set of configuration variables instead of a list - flat networks performed fine in this regard.
2020-07-09 17:34:51 -06:00


import torch
from torch import nn
from switched_conv import BareConvSwitch, compute_attention_specificity
import torch.nn.functional as F
import functools
from collections import OrderedDict
from models.archs.arch_util import ConvBnLelu, ConvGnSilu
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
from models.archs.spinenet_arch import SpineNet
from switched_conv_util import save_attention_to_image


class MultiConvBlock(nn.Module):
    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False, weight_init_factor=1):
        assert depth >= 2
        super(MultiConvBlock, self).__init__()
        self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor)] +
                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False, weight_init_factor=weight_init_factor)])
        self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        if noise is not None:
            noise = noise * self.noise_scale
            x = x + noise
        for m in self.bnconvs:
            x = m.forward(x)
        return x * self.scale + self.bias
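
# Usage sketch (illustrative values, not part of the original file): a MultiConvBlock that keeps the
# channel count at 64 while widening to 96 internally. Spatial size is preserved; the output of the
# conv stack is modulated by the learned `scale` and `bias` parameters.
#
#   block = MultiConvBlock(filters_in=64, filters_mid=96, filters_out=64, kernel_size=3, depth=3,
#                          scale_init=1., weight_init_factor=.1)
#   out = block(torch.randn(1, 64, 32, 32))   # -> torch.Size([1, 64, 32, 32])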


# VGG-style layer with Conv(stride2)->Activation->Conv->GN->Activation.
# Doubles the input filter count and halves the spatial resolution.
class HalvingProcessingBlock(nn.Module):
    def __init__(self, filters):
        super(HalvingProcessingBlock, self).__init__()
        self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, gn=False, bias=False)
        self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, gn=True, bias=False)

    def forward(self, x):
        x = self.bnconv1(x)
        return self.bnconv2(x)


# U-net expansion block: upsamples 2x, halves the filter count, and merges in a skip ("passthrough")
# feature from the corresponding reduction stage.
class ExpansionBlock(nn.Module):
    def __init__(self, filters):
        super(ExpansionBlock, self).__init__()
        self.decimate = ConvGnSilu(filters, filters // 2, kernel_size=1, bias=False, silu=False, gn=False)
        self.conjoin = ConvGnSilu(filters, filters // 2, kernel_size=3, bias=True, silu=False, gn=True)
        self.process = ConvGnSilu(filters // 2, filters // 2, kernel_size=3, bias=False, silu=True, gn=True)

    def forward(self, input, passthrough):
        x = F.interpolate(input, scale_factor=2, mode="nearest")
        x = self.decimate(x)
        x = self.conjoin(torch.cat([x, passthrough], dim=1))
        return self.process(x)
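
# Sketch of one decoder step (hypothetical shapes): the block upsamples 2x, halves the channel count
# with the 1x1 `decimate` conv, fuses the skip connection via `conjoin`, then refines with `process`.
#
#   up = ExpansionBlock(128)
#   deep = torch.randn(1, 128, 16, 16)        # output of the previous (lower-resolution) level
#   skip = torch.randn(1, 64, 32, 32)         # matching encoder feature ("passthrough")
#   out = up(deep, skip)                      # -> torch.Size([1, 64, 32, 32])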


# This is a classic u-net architecture with the goal of assigning each individual pixel an individual
# transform switching set.
class ConvBasisMultiplexer(nn.Module):
    def __init__(self, input_channels, base_filters, reductions, processing_depth, multiplexer_channels, use_gn=True):
        super(ConvBasisMultiplexer, self).__init__()
        self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
        self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(reductions)])
        reduction_filters = base_filters * 2 ** reductions
        self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(processing_depth)]))
        self.expansion_blocks = nn.ModuleList([ExpansionBlock(reduction_filters // (2 ** i)) for i in range(reductions)])

        gap = base_filters - multiplexer_channels
        cbl1_out = ((base_filters - (gap // 2)) // 4) * 4   # Must be multiples of 4 to use with group norm.
        self.cbl1 = ConvGnSilu(base_filters, cbl1_out, gn=use_gn, bias=False, num_groups=4)
        cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4
        self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, gn=use_gn, bias=False, num_groups=4)
        self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, gn=False)

    def forward(self, x):
        x = self.filter_conv(x)
        reduction_identities = []
        for b in self.reduction_blocks:
            reduction_identities.append(x)
            x = b(x)
        x = self.processing_blocks(x)
        for i, b in enumerate(self.expansion_blocks):
            x = b(x, reduction_identities[-i - 1])
        x = self.cbl1(x)
        x = self.cbl2(x)
        x = self.cbl3(x)
        return x
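
# Shape sketch (hypothetical values): the multiplexer produces one selection logit per transform at
# every pixel. H and W should be divisible by 2**reductions so the u-net skip connections line up.
#
#   mux = ConvBasisMultiplexer(input_channels=64, base_filters=64, reductions=3, processing_depth=2,
#                              multiplexer_channels=8)
#   logits = mux(torch.randn(1, 64, 32, 32))  # -> torch.Size([1, 8, 32, 32])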


# Caches the result of a backbone forward pass so that it can be computed once per step and then
# re-read by any module holding a reference to this wrapper.
class CachedBackboneWrapper:
    def __init__(self, backbone: nn.Module):
        self.backbone = backbone

    def __call__(self, *args):
        self.cache = self.backbone(*args)
        return self.cache

    def get_forward_result(self):
        return self.cache
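
# Usage sketch (assumed workflow; `backbone` and `lr_image` are placeholders): run the backbone once,
# then let other modules re-read the cached result without recomputing it.
#
#   wrapper = CachedBackboneWrapper(backbone)
#   features = wrapper(lr_image)              # runs the backbone and caches its output
#   same = wrapper.get_forward_result()       # re-reads the cache, no extra forward pass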


class BackboneMultiplexer(nn.Module):
    def __init__(self, backbone: CachedBackboneWrapper, transform_count):
        super(BackboneMultiplexer, self).__init__()
        self.backbone = backbone
        self.proc = nn.Sequential(ConvGnSilu(256, 256, kernel_size=3, bias=True),
                                  ConvGnSilu(256, 256, kernel_size=3, bias=False))
        self.up1 = nn.Sequential(ConvGnSilu(256, 128, kernel_size=3, bias=False, gn=False, silu=False),
                                 ConvGnSilu(128, 128, kernel_size=3, bias=False))
        self.up2 = nn.Sequential(ConvGnSilu(128, 64, kernel_size=3, bias=False, gn=False, silu=False),
                                 ConvGnSilu(64, 64, kernel_size=3, bias=False))
        self.final = ConvGnSilu(64, transform_count, bias=False, gn=False, silu=False)

    def forward(self, x):
        spine = self.backbone.get_forward_result()
        feat = self.proc(spine[0])
        feat = self.up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
        feat = self.up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
        return self.final(feat)


class ConfigurableSwitchComputer(nn.Module):
    def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
                 add_scalable_noise_to_transforms=False):
        super(ConfigurableSwitchComputer, self).__init__()

        tc = transform_count
        self.multiplexer = multiplexer_net(tc)
        self.pre_transform = pre_transform_block()
        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
        self.add_noise = add_scalable_noise_to_transforms
        self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))

        # And the switch itself, including learned scalars
        self.switch = BareConvSwitch(initial_temperature=init_temp)
        self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
        self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
        # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
        # depending on its needs.
        self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))

    def forward(self, x, output_attention_weights=False):
        identity = x
        if self.add_noise:
            rand_feature = torch.randn_like(x) * self.noise_scale
            x = x + rand_feature

        x = self.pre_transform(x)
        xformed = [t.forward(x) for t in self.transforms]

        m = self.multiplexer(identity)

        outputs, attention = self.switch(xformed, m, True)
        outputs = identity + outputs * self.switch_scale
        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale
        if output_attention_weights:
            return outputs, attention
        else:
            return outputs

    def set_temperature(self, temp):
        self.switch.set_attention_temperature(temp)
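
# Wiring sketch (hypothetical values): the switch takes factory callables. The multiplexer factory is
# called with the transform count as its final positional argument (multiplexer_channels here); the
# pre-transform and transform factories are called with no arguments.
#
#   mux_fn = functools.partial(ConvBasisMultiplexer, 64, 64, 3, 2)
#   pre_fn = functools.partial(ConvBnLelu, 64, 64, bn=False, bias=False, weight_init_factor=.1)
#   tx_fn = functools.partial(MultiConvBlock, 64, 96, 64, kernel_size=3, depth=3, weight_init_factor=.1)
#   sw = ConfigurableSwitchComputer(64, mux_fn, pre_fn, tx_fn, transform_count=8, init_temp=10)
#   y, att = sw(torch.randn(1, 64, 32, 32), output_attention_weights=True)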


class ConfigurableSwitchedResidualGenerator2(nn.Module):
    def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                 heightened_final_step=50000, upsample_factor=1,
                 add_scalable_noise_to_transforms=False):
        super(ConfigurableSwitchedResidualGenerator2, self).__init__()
        switches = []
        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
        for _ in range(switch_depth):
            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)
            transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
            switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                       pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                       transform_count=trans_counts, init_temp=initial_temp,
                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))

        self.switches = nn.ModuleList(switches)
        self.transformation_counts = trans_counts
        self.init_temperature = initial_temp
        self.final_temperature_step = final_temperature_step
        self.heightened_temp_min = heightened_temp_min
        self.heightened_final_step = heightened_final_step
        self.attentions = None
        self.upsample_factor = upsample_factor
        assert self.upsample_factor == 2 or self.upsample_factor == 4

    def forward(self, x):
        x = self.initial_conv(x)

        self.attentions = []
        for sw in self.switches:
            x, att = sw.forward(x, True)
            self.attentions.append(att)

        x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
        if self.upsample_factor > 2:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.upconv2(x)
        return self.final_conv(self.hr_conv(x)),

    def set_temperature(self, temp):
        [sw.set_temperature(temp) for sw in self.switches]

    def update_for_step(self, step, experiments_path='.'):
        if self.attentions:
            temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
            if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
                # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
                # Without this, the attention specificity "spikes" incredibly fast in the last few iterations.
                h_steps_total = self.heightened_final_step - self.final_temperature_step
                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
                # The "gap" will represent the steps that need to be traveled as a linear function.
                h_gap = 1 / self.heightened_temp_min
                temp = h_gap * h_steps_current / h_steps_total
                # Invert temperature to represent reality on this side of the curve
                temp = 1 / temp
            self.set_temperature(temp)
            if step % 50 == 0:
                # trans_counts is a single value now (see the commit note), so pass it directly rather than indexing it.
                [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i + 1,)) for i in range(len(self.switches))]
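
    # Schedule sketch (illustrative numbers, not from the original file): with init_temperature=10 and
    # final_temperature_step=50000, the linear phase yields temp=10 at step 0, 5 at step 25000, and is
    # clamped to 1 shortly after step 40000. In the heightened phase, the two lines above reduce to
    #   temp = heightened_temp_min * h_steps_total / h_steps_current
    # i.e. a hyperbolic decay that reaches exactly heightened_temp_min at heightened_final_step.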

    def get_debug_values(self, step):
        temp = self.switches[0].switch.temperature
        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
        means = [i[0] for i in mean_hists]
        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
        val = {"switch_temperature": temp}
        for i in range(len(means)):
            val["switch_%i_specificity" % (i,)] = means[i]
            val["switch_%i_histogram" % (i,)] = hists[i]
        return val
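
# Instantiation sketch (hypothetical hyperparameters): note that forward() returns a 1-tuple, per the
# trailing comma in its return statement.
#
#   gen = ConfigurableSwitchedResidualGenerator2(switch_depth=4, switch_filters=64, switch_reductions=3,
#                                                switch_processing_layers=2, trans_counts=8,
#                                                trans_kernel_sizes=3, trans_layers=4,
#                                                transformation_filters=64, upsample_factor=4)
#   hr, = gen(torch.randn(1, 3, 32, 32))      # -> torch.Size([1, 3, 128, 128])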


class Interpolate(nn.Module):
    def __init__(self, factor):
        super(Interpolate, self).__init__()
        self.factor = factor

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.factor)


class ConfigurableSwitchedResidualGenerator3(nn.Module):
    def __init__(self, base_filters, trans_count, initial_temp=20, final_temperature_step=50000,
                 heightened_temp_min=1,
                 heightened_final_step=50000, upsample_factor=4):
        super(ConfigurableSwitchedResidualGenerator3, self).__init__()
        self.initial_conv = ConvBnLelu(3, base_filters, bn=False, lelu=False, bias=True)
        self.sw_conv = ConvBnLelu(base_filters, base_filters, lelu=False, bias=True)
        self.upconv1 = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
        self.upconv2 = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
        self.hr_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
        self.final_conv = ConvBnLelu(base_filters, 3, bn=False, lelu=False, bias=True)

        # The SpineNet backbone is used purely as a frozen feature extractor for the multiplexer.
        self.backbone = SpineNet('49', in_channels=3, use_input_norm=True)
        for p in self.backbone.parameters(recurse=True):
            p.requires_grad = False
        self.backbone_wrapper = CachedBackboneWrapper(self.backbone)
        multiplx_fn = functools.partial(BackboneMultiplexer, self.backbone_wrapper)
        pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, bn=False, lelu=False, bias=False))
        transform_fn = functools.partial(MultiConvBlock, base_filters, int(base_filters * 1.5), base_filters, kernel_size=3, depth=4)
        # ConfigurableSwitchComputer no longer takes an init_scalar argument (low initialization is handled
        # through weight_init_factor per the commit note), so it is not passed here.
        self.switch = ConfigurableSwitchComputer(base_filters, multiplx_fn, pretransform_fn, transform_fn, trans_count, init_temp=initial_temp,
                                                 add_scalable_noise_to_transforms=True)

        self.transformation_counts = trans_count
        self.init_temperature = initial_temp
        self.final_temperature_step = final_temperature_step
        self.heightened_temp_min = heightened_temp_min
        self.heightened_final_step = heightened_final_step
        self.attentions = None
        self.upsample_factor = upsample_factor
        self.backbone_forward = None

    def get_forward_results(self):
        return self.backbone_forward

    def forward(self, x):
        self.backbone_forward = self.backbone_wrapper(F.interpolate(x, scale_factor=2, mode="nearest"))

        x = self.initial_conv(x)

        self.attentions = []
        x, att = self.switch(x, output_attention_weights=True)
        self.attentions.append(att)

        x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
        if self.upsample_factor > 2:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.upconv2(x)
        return self.final_conv(self.hr_conv(x)),

    def set_temperature(self, temp):
        self.switch.set_temperature(temp)

    def update_for_step(self, step, experiments_path='.'):
        if self.attentions:
            temp = max(1, int(
                self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
            if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
                # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
                # Without this, the attention specificity "spikes" incredibly fast in the last few iterations.
                h_steps_total = self.heightened_final_step - self.final_temperature_step
                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
                # The "gap" will represent the steps that need to be traveled as a linear function.
                h_gap = 1 / self.heightened_temp_min
                temp = h_gap * h_steps_current / h_steps_total
                # Invert temperature to represent reality on this side of the curve
                temp = 1 / temp
            self.set_temperature(temp)
            if step % 50 == 0:
                save_attention_to_image(experiments_path, self.attentions[0], self.transformation_counts, step, "a%i" % (1,), l_mult=10)

    def get_debug_values(self, step):
        temp = self.switch.switch.temperature
        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
        means = [i[0] for i in mean_hists]
        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
        val = {"switch_temperature": temp}
        for i in range(len(means)):
            val["switch_%i_specificity" % (i,)] = means[i]
            val["switch_%i_histogram" % (i,)] = hists[i]
        return val
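
# Instantiation sketch (hypothetical values; output size assumes the backbone features line up with a
# 64x64 input): SRG3 drives a single switch from a frozen SpineNet backbone that is run on a
# 2x-upsampled copy of the input, with BackboneMultiplexer reading the cached backbone features.
#
#   gen = ConfigurableSwitchedResidualGenerator3(base_filters=64, trans_count=8)
#   hr, = gen(torch.randn(1, 3, 64, 64))      # -> torch.Size([1, 3, 256, 256])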