forked from mrq/DL-Art-School
This branches off of SPSR. It is identical but substantially reduced in complexity. It is intended to be my long-term working architecture.
220 lines
12 KiB
220 lines
12 KiB
import math
import functools
from models.archs.arch_util import MultiConvBlock, ConvGnLelu, ConvGnSilu, ReferenceJoinBlock
from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer
from models.archs.SPSR_arch import ImageGradientNoPadding
from torch import nn
import torch
import torch.nn.functional as F
from switched_conv_util import save_attention_to_image_rgb
from switched_conv import compute_attention_specificity
import os
# Two-stage processing block: a strided 1x1 ConvGnSilu (norm disabled) that
# doubles the filter count, followed by a second ConvGnSilu (norm enabled)
# that keeps the doubled filter count.
class HalvingProcessingBlock(nn.Module):
    def __init__(self, filters):
        super().__init__()
        doubled = filters * 2
        # stride=2 is handed to ConvGnSilu (defined elsewhere); presumably this
        # halves the spatial resolution, as the class name suggests - confirm.
        self.bnconv1 = ConvGnSilu(filters, doubled, kernel_size=1, stride=2, norm=False, bias=False)
        self.bnconv2 = ConvGnSilu(doubled, doubled, norm=True, bias=False)

    def forward(self, x):
        # Apply the strided layer first, then the normalized layer.
        return self.bnconv2(self.bnconv1(x))
class ExpansionBlock2(nn.Module):
    """Upsample a feature signal by 2x and merge it with a passthrough signal.

    Args:
        filters_in: channel count of the feature signal given to forward().
        filters_out: channel count of the passthrough signal and of the
            output. Defaults to ``filters_in // 2``.
        block: layer factory used for every internal layer. It is called as
            ``block(in, out, kernel_size=..., bias=..., activation=..., norm=...)``.
    """

    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu):
        super(ExpansionBlock2, self).__init__()
        if filters_out is None:
            filters_out = filters_in // 2
        self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=True, norm=False)
        self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=True, norm=False)
        self.conjoin = block(filters_out*2, filters_out*2, kernel_size=1, bias=False, activation=True, norm=False)
        self.reduce = block(filters_out*2, filters_out, kernel_size=1, bias=False, activation=False, norm=True)

    # input is the feature signal with shape (b, f, w, h)
    # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
    # output is conjoined upsample with shape (b, f/2, w*2, h*2)
    def forward(self, input, passthrough):
        x = F.interpolate(input, scale_factor=2, mode="nearest")
        x = self.decimate(x)
        p = self.process_passthrough(passthrough)
        # Both tensors carry filters_out channels here; concatenating along the
        # channel axis yields the filters_out*2 channels that `conjoin` was
        # constructed for. (The original line passed the list and `dim=1`
        # straight to the module and did not parse.)
        x = self.conjoin(torch.cat([x, p], dim=1))
        return self.reduce(x)
# Interpolation-based upsampling block: a 1x1 channel-changing layer, a 2x
# nearest-neighbour interpolation, then a 3x3 processing layer.
class UpconvBlock(nn.Module):
    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
        super().__init__()
        # NOTE(review): filters_out defaults to None and is forwarded to `block`
        # unchanged; callers in this file always pass it explicitly.
        # The 1x1 layer always runs without bias, activation or norm; the
        # norm/activation/bias arguments only configure the 3x3 layer.
        self.reduce = block(filters_in, filters_out, kernel_size=1, norm=False, activation=False, bias=False)
        self.process = block(filters_out, filters_out, kernel_size=3, norm=norm, activation=activation, bias=bias)

    def forward(self, x):
        reduced = self.reduce(x)
        upsampled = F.interpolate(reduced, scale_factor=2, mode="nearest")
        return self.process(upsampled)
class SSGMultiplexer(nn.Module):
    """Produce one single-channel map per transformation from a query/key pair.

    The query is built from the input features and an embedding, run through
    `reductions` HalvingProcessingBlocks and then back up through the same
    number of ExpansionBlock2s. The key is built from each transformation
    output. Query and key are concatenated and reduced to one channel.

    Args:
        nf: channel count of the input features and of each transformation.
        multiplexer_channels: accepted for interface compatibility; not used
            anywhere in this class.
        reductions: number of halving (and matching expansion) stages.
    """

    def __init__(self, nf, multiplexer_channels, reductions=2):
        super(SSGMultiplexer, self).__init__()

        # Blocks used to create the query
        self.input_process = ConvGnSilu(nf, nf, activation=True, norm=False, bias=True)
        # The embedding is hard-coded to 256 channels here and in processing_blocks.
        self.embedding_process = ConvGnSilu(256, 256, activation=True, norm=False, bias=True)
        self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(nf * 2 ** i) for i in range(reductions)])
        reduction_filters = nf * 2 ** reductions
        self.processing_blocks = nn.Sequential(
            ConvGnSilu(reduction_filters + 256, reduction_filters + 128, kernel_size=1, activation=True, norm=False, bias=True),
            ConvGnSilu(reduction_filters + 128, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False))
        self.expansion_blocks = nn.ModuleList([ExpansionBlock2(reduction_filters // (2 ** i)) for i in range(reductions)])

        # Blocks used to create the key
        self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)

        # Postprocessing blocks.
        self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=1, activation=True, norm=False, bias=False)
        self.cbl1 = ConvGnSilu(nf, nf // 4, kernel_size=1, activation=True, norm=True, bias=False, num_groups=4)
        self.cbl2 = ConvGnSilu(nf // 4, 1, kernel_size=1, activation=False, norm=False, bias=False)

    def forward(self, x, embedding, transformations):
        """Return a (b, t, h, w) tensor with one map per transformation.

        Args:
            x: input features; must end up as (b, f, h, w) after the
                reduce/expand round trip so it can be viewed against
                `transformations`.
            embedding: 256-channel tensor. It is concatenated with the reduced
                query, so its spatial size presumably has to match the reduced
                resolution - TODO confirm against the caller.
            transformations: tensor of shape (b, t, f, h, w).
        """
        q = self.input_process(x)
        embedding = self.embedding_process(embedding)

        # Keep the features seen *before* each halving stage; the expansion
        # stages consume them in reverse order as their passthrough signal
        # (ExpansionBlock2 expects a passthrough with half the channels and
        # twice the resolution of its input). Without this the lookup below
        # indexes an empty list.
        reduction_identities = []
        for reduce_block in self.reduction_blocks:
            reduction_identities.append(q)
            q = reduce_block(q)

        # processing_blocks was built for reduction_filters + 256 input
        # channels, i.e. the reduced query concatenated with the embedding.
        q = self.processing_blocks(torch.cat([q, embedding], dim=1))
        for i, expand_block in enumerate(self.expansion_blocks):
            q = expand_block(q, reduction_identities[-i - 1])

        b, t, f, h, w = transformations.shape
        # Fold the transformation axis into the batch axis to compute the keys.
        k = transformations.view(b * t, f, h, w)
        k = self.key_process(k)

        # Repeat the query once per transformation so it lines up with the keys.
        q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
        # query_key_combine was built for nf*2 channels: query and key stacked
        # along the channel axis.
        v = self.query_key_combine(torch.cat([q, k], dim=1))
        v = self.cbl1(v)
        v = self.cbl2(v)

        return v.view(b, t, h, w)
class SSGr1(nn.Module):
    """Three-branch switched generator: feature, gradient and conjoined branches.

    Each branch runs a ConfigurableSwitchComputer whose multiplexer is an
    SSGMultiplexer. forward() takes an image tensor plus an embedding and
    returns the upsampled gradient-branch output, the upsampled final output
    and the low-resolution gradient features.

    Args:
        in_nc: input channel count.
        out_nc: output channel count.
        nf: base filter count used throughout the network.
        xforms: transform count for the feature and conjoin switches (the
            gradient switch uses ``xforms // 2``).
        upscale: upscale factor; ``int(log2(upscale))`` UpconvBlocks are chained.
        init_temperature: initial switch temperature, also used by the
            schedule in update_for_step().
    """

    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
        super(SSGr1, self).__init__()
        # log2 is exact for powers of two, unlike the two-argument math.log.
        n_upscale = int(math.log2(upscale))

        # switch options
        transformation_filters = nf
        self.transformation_counts = xforms
        multiplx_fn = functools.partial(SSGMultiplexer, transformation_filters)
        # NOTE(review): the original statement broke off after `depth=3,` with
        # no closing parenthesis. It is closed here with only the arguments
        # that were visible; confirm against upstream that no further keyword
        # arguments were lost.
        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.25),
                                         transformation_filters, kernel_size=3, depth=3)

        # Feature branch
        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
        # Registered but not referenced by forward() in this file.
        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, kernel_size=1, depth=2)
        self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                              pre_transform_block=None, transform_block=transform_fn,
                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                              add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
        # feature_lr_conv and feature_lr_conv2 are likewise not referenced by forward().
        self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
        self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)

        # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
        self.get_g_nopadding = ImageGradientNoPadding()
        self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
        self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, kernel_size=1, depth=2)
        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2)
        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                  pre_transform_block=None, transform_block=transform_fn,
                                                  transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                  add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        # NOTE(review): every stage is built as nf -> nf // 2, so with more than
        # one stage (upscale >= 4) the second stage is declared for nf input
        # channels but receives nf // 2. Left as in the original - verify.
        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
        self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True)

        # Join branch (grad+fea)
        self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, kernel_size=1, depth=2)
        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, kernel_size=1, depth=2)
        self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                     pre_transform_block=None, transform_block=transform_fn,
                                                     transform_count=self.transformation_counts, init_temp=init_temperature,
                                                     add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
        self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
        self.switches = [self.sw1, self.sw_grad, self.conjoin_sw]
        self.attentions = None
        self.init_temperature = init_temperature
        self.final_temperature_step = 10000

    def forward(self, x, embedding):
        """Run all three branches.

        Returns:
            (x_grad_out, x_out, x_grad): the upsampled gradient-branch output,
            the upsampled conjoined output, and the gradient-branch features
            before upsampling.

        Side effects: stores `attentions`, `noise_stds`, `grad_fea_std` and
        `fea_grad_std` on the module for get_debug_values().
        """
        noise_stds = []

        x_grad = self.get_g_nopadding(x)

        # Feature branch.
        x = self.model_fea_conv(x)
        x1 = x
        x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, embedding))

        # Gradient branch: inject noise, join with the feature branch, switch.
        x_grad = self.grad_conv(x_grad)
        x_grad_identity = x_grad
        x_grad, nstd = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
        # Collect the value so the mean below has something to stack; the
        # original never filled this list and would fail on torch.stack([]).
        noise_stds.append(nstd)
        x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1)
        x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, embedding))
        x_grad = self.grad_lr_conv(x_grad)
        x_grad_out = self.upsample_grad(x_grad)
        x_grad_out = self.grad_branch_output_conv(x_grad_out)

        # Conjoin branch: inject noise, join with the gradient features, switch.
        x_out = x1
        x_out, nstd = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
        noise_stds.append(nstd)
        x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
        x_out, a4 = self.conjoin_sw(x_out, True, identity=x1, att_in=(x_out, embedding))
        x_out = self.final_lr_conv(x_out)
        x_out = self.upsample(x_out)
        x_out = self.final_hr_conv2(x_out)

        # Stash debugging state, detached and moved to the CPU.
        self.attentions = [a1, a3, a4]
        self.noise_stds = torch.stack(noise_stds).mean().detach().cpu()
        self.grad_fea_std = grad_fea_std.detach().cpu()
        self.fea_grad_std = fea_grad_std.detach().cpu()
        return x_grad_out, x_out, x_grad

    def set_temperature(self, temp):
        """Set the same temperature on every switch."""
        for sw in self.switches:
            sw.set_temperature(temp)

    def update_for_step(self, step, experiments_path='.'):
        """Anneal the switch temperature and periodically dump attention maps.

        Does nothing until forward() has populated `self.attentions`. The
        temperature falls linearly from ``1 + init_temperature`` at step 0 to
        1 at `final_temperature_step`, and is clamped at 1 afterwards. Every
        200 steps the stored attention maps are saved under
        ``<experiments_path>/attention_maps/a<i>``.
        """
        if self.attentions:
            temp = max(1, 1 + self.init_temperature *
                       (self.final_temperature_step - step) / self.final_temperature_step)
            # Apply the schedule; the original computed `temp` and then dropped it.
            self.set_temperature(temp)
            if step % 200 == 0:
                output_path = os.path.join(experiments_path, "attention_maps", "a%i")
                prefix = "attention_map_%i_%%i.png" % (step,)
                for i, attention in enumerate(self.attentions):
                    save_attention_to_image_rgb(output_path % (i,), attention, self.transformation_counts, prefix, step)

    def get_debug_values(self, step):
        """Return a dict of diagnostic values recorded by the last forward().

        `step` is accepted but not used. Requires forward() to have run, since
        it reads the attention maps and std values stored there.
        """
        temp = self.switches[0].switch.temperature
        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
        means = [i[0] for i in mean_hists]
        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
        val = {"switch_temperature": temp,
               "noise_branch_std_dev": self.noise_stds,
               "grad_branch_feat_intg_std_dev": self.grad_fea_std,
               "conjoin_branch_grad_intg_std_dev": self.fea_grad_std}
        for i, (mean, hist) in enumerate(zip(means, hists)):
            val["switch_%i_specificity" % (i,)] = mean
            val["switch_%i_histogram" % (i,)] = hist
        return val