Rev3 of the full image ref arch
parent f35b3ad28f
commit 8a6a2e6e2e

@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from models.archs import SPSR_util as B
 from .RRDBNet_arch import RRDB
 from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock
-from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch
+from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock
 from switched_conv_util import save_attention_to_image_rgb
 from switched_conv import compute_attention_specificity
 import functools
@@ -364,7 +364,7 @@ class SwitchedSpsrWithRef(nn.Module):
         self.transformation_counts = xforms
         self.reference_processor = ReferenceImageBranch(transformation_filters)
         multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters, self.transformation_counts)
-        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
         transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
                                          transformation_filters, kernel_size=3, depth=3,
                                          weight_init_factor=.1)
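A note on the pattern above, since it recurs throughout this commit: the functools.partial calls do not construct layers yet, they bind constructor arguments into factories that ConfigurableSwitchComputer presumably invokes later to instantiate one transform per switch. The sketch below is a standalone illustration of that factory idiom with made-up block names, not code from this repository. Because pretransform_fn is now bound to AdaInConvBlock, each instantiated pre-transform expects the 512-dim reference embedding alongside the feature map, which is what forces the (x, ref) tuple calling convention seen further down.

import functools
import torch.nn as nn

# Hypothetical stand-ins for the repo's block classes, only to show the partial-as-factory idea.
def make_block(in_ch, out_ch, depth=1):
    layers = []
    for _ in range(depth):
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
        in_ch = out_ch
    return nn.Sequential(*layers)

class ToySwitch(nn.Module):
    # Consumes a zero-argument factory; each call yields an independently parameterized block.
    def __init__(self, block_factory, num_transforms):
        super().__init__()
        self.transforms = nn.ModuleList([block_factory() for _ in range(num_transforms)])

transform_fn = functools.partial(make_block, 64, 64, depth=3)   # bind constructor args now...
switch = ToySwitch(transform_fn, num_transforms=8)              # ...instantiate blocks later
print(len(switch.transforms))  # 8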
@@ -400,11 +400,10 @@ class SwitchedSpsrWithRef(nn.Module):
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)

         # Conjoin branch.
         # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
         transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
                                              transformation_filters, kernel_size=3, depth=4,
                                              weight_init_factor=.1)
-        pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
+        pretransform_fn_cat = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
         self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                               pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
                                                               attention_norm=True,
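The "_branch_pretrain" comment implies the training code picks out these parameters by name; that selection logic is not part of this diff. As a generic, hypothetical illustration of how a name tag like that can be consumed in PyTorch (the function and names below are mine, not the repository's):

import torch.nn as nn

def tagged_parameters(model: nn.Module, tag: str = "_branch_pretrain"):
    # Yield only the parameters whose qualified name contains the tag,
    # e.g. to build an optimizer that updates just the pretraining branch.
    for name, param in model.named_parameters():
        if tag in name:
            yield param

# usage sketch, assuming a model that holds a `_branch_pretrain_sw` submodule:
# optimizer = torch.optim.Adam(tagged_parameters(model), lr=1e-4)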
@@ -425,20 +424,20 @@ class SwitchedSpsrWithRef(nn.Module):
         ref = self.reference_processor(ref, center_coord)
         x = self.model_fea_conv(x)

-        x1, a1 = self.sw1(x, True, att_in=(x, ref))
-        x2, a2 = self.sw2(x1, True, att_in=(x, ref))
+        x1, a1 = self.sw1((x, ref), True)
+        x2, a2 = self.sw2((x1, ref), True)
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_hr_conv2(x_fea)

         x_b_fea = self.b_fea_conv(x_grad)
-        x_grad, a3 = self.sw_grad(x_b_fea, att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True)
+        x_grad, a3 = self.sw_grad((x_b_fea, ref), att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_hr_conv(x_grad)
         x_out_branch = self.upsample_grad(x_grad)
         x_out_branch = self.grad_branch_output_conv(x_out_branch)

         x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
-        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True)
+        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw((x__branch_pretrain_cat, ref), att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True)
         x_out = self.final_lr_conv(x__branch_pretrain_cat)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
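Why these call sites changed: the switch pre-transforms are now AdaInConvBlocks, so the reference embedding must travel with the main input as a tuple, while att_in remains a separate keyword for the cases (sw_grad, _branch_pretrain_sw) where the multiplexer should attend over a different tensor than the one being transformed. Below is a stripped-down, hypothetical module showing the accept-tensor-or-tuple convention; it is a sketch, not the repository's switch implementation.

import torch
import torch.nn as nn

class ToyRefSwitch(nn.Module):
    # Accepts either a tensor or an (x, ref) tuple; the first element is always the
    # tensor that flows through the transform and the residual connection.
    def __init__(self, channels, ref_dim=512):
        super().__init__()
        self.modulate = nn.Linear(ref_dim, channels)
        self.transform = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x, output_attention_weights=False):
        if isinstance(x, tuple):
            x, ref = x
            b, c, _, _ = x.shape
            x = x * self.modulate(ref).view(b, c, 1, 1)   # reference-conditioned scaling
        out = x + self.transform(x)                       # residual transform
        fake_attention = torch.zeros(x.shape[0])          # placeholder for attention weights
        return (out, fake_attention) if output_attention_weights else out

x, ref = torch.randn(2, 64, 16, 16), torch.randn(2, 512)
y, a = ToyRefSwitch(64)((x, ref), True)
print(y.shape)  # torch.Size([2, 64, 16, 16])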
@@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock, ConvGnLelu
 from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB
 from models.archs.spinenet_arch import SpineNet
 from switched_conv_util import save_attention_to_image_rgb
@@ -131,16 +131,32 @@ class ReferenceImageBranch(nn.Module):
         return gather_2d(x, center_point // 8)


+class AdaInConvBlock(nn.Module):
+    def __init__(self, reference_size, in_nc, out_nc, conv_block=ConvGnLelu):
+        super(AdaInConvBlock, self).__init__()
+        self.filter_conv = conv_block(in_nc, out_nc, activation=True, norm=False, bias=False)
+        self.ref_proc = nn.Linear(reference_size, reference_size)
+        self.ref_red = nn.Linear(reference_size, out_nc * 2)
+        self.feature_norm = torch.nn.InstanceNorm2d(out_nc)
+        self.style_norm = torch.nn.InstanceNorm1d(out_nc)
+        self.post_fuse_conv = conv_block(out_nc, out_nc, activation=False, norm=True, bias=True)
+
+    def forward(self, x, ref):
+        x = self.feature_norm(self.filter_conv(x))
+        ref = self.ref_proc(ref)
+        ref = self.ref_red(ref)
+        b, c = ref.shape
+        ref = self.style_norm(ref.view(b, 2, c // 2))
+        x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
+        return self.post_fuse_conv(x)
+
+
 # This is similar to ConvBasisMultiplexer, except that it takes a linear reference tensor as a second input to
 # provide better results. It also has fixed parameterization in several places
 class ReferencingConvMultiplexer(nn.Module):
     def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
         super(ReferencingConvMultiplexer, self).__init__()
-        self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
-        self.ref_proc = nn.Linear(512, 512)
-        self.ref_red = nn.Linear(512, base_filters * 2)
-        self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
-        self.style_norm = torch.nn.InstanceNorm1d(base_filters)
+        self.style_fuse = AdaInConvBlock(512, input_channels, base_filters, ConvGnSilu)

         self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(3)])
         reduction_filters = base_filters * 2 ** 3
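The hunks above and below sit in the module imported earlier as models.archs.SwitchedResidualGenerator_arch, judging by the import this commit extends. AdaInConvBlock follows the usual adaptive instance normalization recipe: instance-normalize the convolved content features, then scale and shift each channel with a pair of vectors regressed from the reference embedding. A minimal standalone sketch of just that modulation step, with illustrative names rather than the repository's API:

import torch
import torch.nn as nn

class TinyAdaIn(nn.Module):
    # Illustrative channel-wise AdaIN modulation; not the repo's AdaInConvBlock.
    def __init__(self, ref_dim, channels):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels)                 # strip per-sample feature statistics
        self.to_scale_shift = nn.Linear(ref_dim, channels * 2)  # regress (gamma, beta) from the reference vector

    def forward(self, x, ref):
        b, c, _, _ = x.shape
        gamma, beta = self.to_scale_shift(ref).view(b, 2, c).unbind(dim=1)
        x = self.norm(x)
        return x * gamma.view(b, c, 1, 1) + beta.view(b, c, 1, 1)

# usage: a 512-dim reference vector modulating a 64-channel feature map
feat = torch.randn(2, 64, 32, 32)
ref = torch.randn(2, 512)
print(TinyAdaIn(512, 64)(feat, ref).shape)  # torch.Size([2, 64, 32, 32])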
@@ -155,13 +171,7 @@ class ReferencingConvMultiplexer(nn.Module):
         self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False)

     def forward(self, x, ref):
-        # Start by fusing the reference vector and the input. Follows the ADAIn formula.
-        x = self.feature_norm(self.filter_conv(x))
-        ref = self.ref_proc(ref)
-        ref = self.ref_red(ref)
-        b, c = ref.shape
-        ref = self.style_norm(ref.view(b, 2, c // 2))
-        x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
+        x = self.style_fuse(x, ref)

         reduction_identities = []
         for b in self.reduction_blocks:
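Note that this forward change is not a pure code move: the old inline path stopped at the modulation, whereas AdaInConvBlock also applies post_fuse_conv afterwards, so the multiplexer's fused features now pass through one extra convolution. A compressed sketch contrasting the two compute paths, using toy layers rather than the repository's:

import torch
import torch.nn as nn

# Toy stand-ins to contrast the two paths; not the repository's layers.
conv_in   = nn.Conv2d(3, 8, 3, padding=1)
norm      = nn.InstanceNorm2d(8)
to_style  = nn.Linear(512, 16)            # -> (scale, shift), 8 channels each
post_conv = nn.Conv2d(8, 8, 3, padding=1)

def modulate(x, ref):
    b = x.shape[0]
    style = to_style(ref).view(b, 2, 8)
    scale, shift = style[:, 0], style[:, 1]
    return x * scale.view(b, 8, 1, 1) + shift.view(b, 8, 1, 1)

def old_path(x, ref):                     # previous inline fusion: ended at the modulation
    return modulate(norm(conv_in(x)), ref)

def new_path(x, ref):                     # AdaInConvBlock-style fusion: extra conv after modulation
    return post_conv(modulate(norm(conv_in(x)), ref))

x, ref = torch.randn(2, 3, 16, 16), torch.randn(2, 512)
print(old_path(x, ref).shape, new_path(x, ref).shape)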
@@ -221,26 +231,43 @@ class ConfigurableSwitchComputer(nn.Module):
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))

+    # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
+    # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
     def forward(self, x, output_attention_weights=False, identity=None, att_in=None, fixed_scale=1):
+        if isinstance(x, tuple):
+            x1 = x[0]
+        else:
+            x1 = x
+
         if att_in is None:
             att_in = x

         if identity is None:
-            identity = x
+            identity = x1

         if self.add_noise:
-            rand_feature = torch.randn_like(x) * self.noise_scale
-            x = x + rand_feature
+            rand_feature = torch.randn_like(x1) * self.noise_scale
+            if isinstance(x, tuple):
+                x = (x1 + rand_feature,) + x[1:]
+            else:
+                x = x1 + rand_feature

         if self.pre_transform:
+            if isinstance(x, tuple):
+                x = self.pre_transform(*x)
+            else:
                 x = self.pre_transform(x)
+        if isinstance(x, tuple):
+            xformed = [t.forward(*x) for t in self.transforms]
+        else:
             xformed = [t.forward(x) for t in self.transforms]
+        if isinstance(att_in, tuple):
+            m = self.multiplexer(*att_in)
+        else:
             m = self.multiplexer(att_in)

         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
         outputs = identity + outputs * self.switch_scale * fixed_scale
         outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
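One detail worth keeping in mind about the unchanged tail of this method: the switch output and the post-switch conv are both mixed in through small learnable scales (switch_scale, and psc_scale initialized at 0.1 above), so a freshly built switch starts out close to an identity mapping and learns how much of the transform to blend in. A self-contained sketch of that scaled-residual pattern, with hypothetical names:

import torch
import torch.nn as nn

class ScaledResidual(nn.Module):
    # Residual block whose contribution starts small: with scale initialized at 0.1,
    # the block begins near the identity and learns how far to deviate from it.
    def __init__(self, fn, init_scale=0.1):
        super().__init__()
        self.fn = fn
        self.scale = nn.Parameter(torch.full((1,), float(init_scale)))

    def forward(self, x, fixed_scale=1.0):
        return x + self.fn(x) * self.scale * fixed_scale

block = ScaledResidual(nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(1, 8, 16, 16)
print((block(x) - x).abs().mean().item())  # small deviation from the input at init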