Rev3 of the full image ref arch

This commit is contained in:
James Betker 2020-08-26 17:11:01 -06:00
parent f35b3ad28f
commit 8a6a2e6e2e
2 changed files with 53 additions and 27 deletions

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from models.archs import SPSR_util as B
from .RRDBNet_arch import RRDB
from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock
from switched_conv_util import save_attention_to_image_rgb
from switched_conv import compute_attention_specificity
import functools
@ -364,7 +364,7 @@ class SwitchedSpsrWithRef(nn.Module):
self.transformation_counts = xforms
self.reference_processor = ReferenceImageBranch(transformation_filters)
multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters, self.transformation_counts)
pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
transformation_filters, kernel_size=3, depth=3,
weight_init_factor=.1)
@ -400,11 +400,10 @@ class SwitchedSpsrWithRef(nn.Module):
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
# Conjoin branch.
# Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
transformation_filters, kernel_size=3, depth=4,
weight_init_factor=.1)
pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
pretransform_fn_cat = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
attention_norm=True,
@ -425,20 +424,20 @@ class SwitchedSpsrWithRef(nn.Module):
ref = self.reference_processor(ref, center_coord)
x = self.model_fea_conv(x)
x1, a1 = self.sw1(x, True, att_in=(x, ref))
x2, a2 = self.sw2(x1, True, att_in=(x, ref))
x1, a1 = self.sw1((x, ref), True)
x2, a2 = self.sw2((x1, ref), True)
x_fea = self.feature_lr_conv(x2)
x_fea = self.feature_hr_conv2(x_fea)
x_b_fea = self.b_fea_conv(x_grad)
x_grad, a3 = self.sw_grad(x_b_fea, att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True)
x_grad, a3 = self.sw_grad((x_b_fea, ref), att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True)
x_grad = self.grad_lr_conv(x_grad)
x_grad = self.grad_hr_conv(x_grad)
x_out_branch = self.upsample_grad(x_grad)
x_out_branch = self.grad_branch_output_conv(x_out_branch)
x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True)
x__branch_pretrain_cat, a4 = self._branch_pretrain_sw((x__branch_pretrain_cat, ref), att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True)
x_out = self.final_lr_conv(x__branch_pretrain_cat)
x_out = self.upsample(x_out)
x_out = self.final_hr_conv1(x_out)

View File

@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
import torch.nn.functional as F
import functools
from collections import OrderedDict
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock, ConvGnLelu
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB
from models.archs.spinenet_arch import SpineNet
from switched_conv_util import save_attention_to_image_rgb
@ -131,16 +131,32 @@ class ReferenceImageBranch(nn.Module):
return gather_2d(x, center_point // 8)
class AdaInConvBlock(nn.Module):
def __init__(self, reference_size, in_nc, out_nc, conv_block=ConvGnLelu):
super(AdaInConvBlock, self).__init__()
self.filter_conv = conv_block(in_nc, out_nc, activation=True, norm=False, bias=False)
self.ref_proc = nn.Linear(reference_size, reference_size)
self.ref_red = nn.Linear(reference_size, out_nc * 2)
self.feature_norm = torch.nn.InstanceNorm2d(out_nc)
self.style_norm = torch.nn.InstanceNorm1d(out_nc)
self.post_fuse_conv = conv_block(out_nc, out_nc, activation=False, norm=True, bias=True)
def forward(self, x, ref):
x = self.feature_norm(self.filter_conv(x))
ref = self.ref_proc(ref)
ref = self.ref_red(ref)
b, c = ref.shape
ref = self.style_norm(ref.view(b, 2, c // 2))
x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
return self.post_fuse_conv(x)
# This is similar to ConvBasisMultiplexer, except that it takes a linear reference tensor as a second input to
# provide better results. It also has fixed parameterization in several places
class ReferencingConvMultiplexer(nn.Module):
def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
super(ReferencingConvMultiplexer, self).__init__()
self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
self.ref_proc = nn.Linear(512, 512)
self.ref_red = nn.Linear(512, base_filters * 2)
self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
self.style_norm = torch.nn.InstanceNorm1d(base_filters)
self.style_fuse = AdaInConvBlock(512, input_channels, base_filters, ConvGnSilu)
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(3)])
reduction_filters = base_filters * 2 ** 3
@ -155,13 +171,7 @@ class ReferencingConvMultiplexer(nn.Module):
self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False)
def forward(self, x, ref):
# Start by fusing the reference vector and the input. Follows the ADAIn formula.
x = self.feature_norm(self.filter_conv(x))
ref = self.ref_proc(ref)
ref = self.ref_red(ref)
b, c = ref.shape
ref = self.style_norm(ref.view(b, 2, c // 2))
x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
x = self.style_fuse(x, ref)
reduction_identities = []
for b in self.reduction_blocks:
@ -221,26 +231,43 @@ class ConfigurableSwitchComputer(nn.Module):
# depending on its needs.
self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
# Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
# *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
def forward(self, x, output_attention_weights=False, identity=None, att_in=None, fixed_scale=1):
if isinstance(x, tuple):
x1 = x[0]
else:
x1 = x
if att_in is None:
att_in = x
if identity is None:
identity = x
identity = x1
if self.add_noise:
rand_feature = torch.randn_like(x) * self.noise_scale
x = x + rand_feature
rand_feature = torch.randn_like(x1) * self.noise_scale
if isinstance(x, tuple):
x = (x1 + rand_feature,) + x[1:]
else:
x = x1 + rand_feature
if self.pre_transform:
if isinstance(x, tuple):
x = self.pre_transform(*x)
else:
x = self.pre_transform(x)
if isinstance(x, tuple):
xformed = [t.forward(*x) for t in self.transforms]
else:
xformed = [t.forward(x) for t in self.transforms]
if isinstance(att_in, tuple):
m = self.multiplexer(*att_in)
else:
m = self.multiplexer(att_in)
# It is assumed that [xformed] and [m] are collapsed into tensors at this point.
outputs, attention = self.switch(xformed, m, True)
outputs = identity + outputs * self.switch_scale * fixed_scale
outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale