diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index ec0c616b..29996bd9 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from models.archs import SPSR_util as B
 from .RRDBNet_arch import RRDB
 from models.archs.arch_util import ConvGnLelu, UpconvBlock, ConjoinBlock
-from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch
+from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer, ReferencingConvMultiplexer, ReferenceImageBranch, AdaInConvBlock
 from switched_conv_util import save_attention_to_image_rgb
 from switched_conv import compute_attention_specificity
 import functools
@@ -364,7 +364,7 @@ class SwitchedSpsrWithRef(nn.Module):
         self.transformation_counts = xforms
         self.reference_processor = ReferenceImageBranch(transformation_filters)
         multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters, self.transformation_counts)
-        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
         transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
                                          transformation_filters, kernel_size=3, depth=3,
                                          weight_init_factor=.1)
@@ -400,11 +400,10 @@ class SwitchedSpsrWithRef(nn.Module):
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)
 
         # Conjoin branch.
-        # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
         transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
                                              transformation_filters, kernel_size=3, depth=4,
                                              weight_init_factor=.1)
-        pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
+        pretransform_fn_cat = functools.partial(AdaInConvBlock, 512, transformation_filters * 2, transformation_filters * 2)
         self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                               pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
                                                               attention_norm=True,
@@ -425,20 +424,20 @@ class SwitchedSpsrWithRef(nn.Module):
         ref = self.reference_processor(ref, center_coord)
         x = self.model_fea_conv(x)
 
-        x1, a1 = self.sw1(x, True, att_in=(x, ref))
-        x2, a2 = self.sw2(x1, True, att_in=(x, ref))
+        x1, a1 = self.sw1((x, ref), True)
+        x2, a2 = self.sw2((x1, ref), True)
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_hr_conv2(x_fea)
 
         x_b_fea = self.b_fea_conv(x_grad)
-        x_grad, a3 = self.sw_grad(x_b_fea, att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True)
+        x_grad, a3 = self.sw_grad((x_b_fea, ref), att_in=(torch.cat([x1, x_b_fea], dim=1), ref), output_attention_weights=True)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_hr_conv(x_grad)
         x_out_branch = self.upsample_grad(x_grad)
         x_out_branch = self.grad_branch_output_conv(x_out_branch)
 
         x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
-        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True)
+        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw((x__branch_pretrain_cat, ref), att_in=(x_fea, ref), identity=x_fea, output_attention_weights=True)
         x_out = self.final_lr_conv(x__branch_pretrain_cat)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 377e0f61..7bcf73ca 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity, Attenti
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConjoinBlock, ConvGnLelu
 from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB
 from models.archs.spinenet_arch import SpineNet
 from switched_conv_util import save_attention_to_image_rgb
@@ -131,16 +131,32 @@ class ReferenceImageBranch(nn.Module):
         return gather_2d(x, center_point // 8)
 
 
+class AdaInConvBlock(nn.Module):
+    def __init__(self, reference_size, in_nc, out_nc, conv_block=ConvGnLelu):
+        super(AdaInConvBlock, self).__init__()
+        self.filter_conv = conv_block(in_nc, out_nc, activation=True, norm=False, bias=False)
+        self.ref_proc = nn.Linear(reference_size, reference_size)
+        self.ref_red = nn.Linear(reference_size, out_nc * 2)
+        self.feature_norm = torch.nn.InstanceNorm2d(out_nc)
+        self.style_norm = torch.nn.InstanceNorm1d(out_nc)
+        self.post_fuse_conv = conv_block(out_nc, out_nc, activation=False, norm=True, bias=True)
+
+    def forward(self, x, ref):
+        x = self.feature_norm(self.filter_conv(x))
+        ref = self.ref_proc(ref)
+        ref = self.ref_red(ref)
+        b, c = ref.shape
+        ref = self.style_norm(ref.view(b, 2, c // 2))
+        x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
+        return self.post_fuse_conv(x)
+
+
 # This is similar to ConvBasisMultiplexer, except that it takes a linear reference tensor as a second input to
 # provide better results. It also has fixed parameterization in several places
 class ReferencingConvMultiplexer(nn.Module):
     def __init__(self, input_channels, base_filters, multiplexer_channels, use_gn=True):
         super(ReferencingConvMultiplexer, self).__init__()
-        self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
-        self.ref_proc = nn.Linear(512, 512)
-        self.ref_red = nn.Linear(512, base_filters * 2)
-        self.feature_norm = torch.nn.InstanceNorm2d(base_filters)
-        self.style_norm = torch.nn.InstanceNorm1d(base_filters)
+        self.style_fuse = AdaInConvBlock(512, input_channels, base_filters, ConvGnSilu)
 
         self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(3)])
         reduction_filters = base_filters * 2 ** 3
@@ -155,13 +171,7 @@ class ReferencingConvMultiplexer(nn.Module):
         self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False)
 
     def forward(self, x, ref):
-        # Start by fusing the reference vector and the input. Follows the ADAIn formula.
-        x = self.feature_norm(self.filter_conv(x))
-        ref = self.ref_proc(ref)
-        ref = self.ref_red(ref)
-        b, c = ref.shape
-        ref = self.style_norm(ref.view(b, 2, c // 2))
-        x = x * ref[:, 0, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape) + ref[:, 1, :].unsqueeze(dim=2).unsqueeze(dim=3).expand(x.shape)
+        x = self.style_fuse(x, ref)
 
         reduction_identities = []
         for b in self.reduction_blocks:
@@ -221,26 +231,43 @@ class ConfigurableSwitchComputer(nn.Module):
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
 
+
+    # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
+    # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
     def forward(self, x, output_attention_weights=False, identity=None, att_in=None, fixed_scale=1):
+        if isinstance(x, tuple):
+            x1 = x[0]
+        else:
+            x1 = x
+
         if att_in is None:
             att_in = x
         if identity is None:
-            identity = x
+            identity = x1
         if self.add_noise:
-            rand_feature = torch.randn_like(x) * self.noise_scale
-            x = x + rand_feature
+            rand_feature = torch.randn_like(x1) * self.noise_scale
+            if isinstance(x, tuple):
+                x = (x1 + rand_feature,) + x[1:]
+            else:
+                x = x1 + rand_feature
         if self.pre_transform:
-            x = self.pre_transform(x)
-        xformed = [t.forward(x) for t in self.transforms]
+            if isinstance(x, tuple):
+                x = self.pre_transform(*x)
+            else:
+                x = self.pre_transform(x)
+        if isinstance(x, tuple):
+            xformed = [t.forward(*x) for t in self.transforms]
+        else:
+            xformed = [t.forward(x) for t in self.transforms]
 
         if isinstance(att_in, tuple):
             m = self.multiplexer(*att_in)
         else:
             m = self.multiplexer(att_in)
-
+        # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
         outputs = identity + outputs * self.switch_scale * fixed_scale
         outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
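For context, the snippet below is a minimal, self-contained sketch of the AdaIN-style fusion this patch introduces: a reference vector is projected to a per-channel scale and shift which modulate an instance-normalized feature map. It is illustrative only; the class name AdaInFusionSketch, the layer name to_style, and the shapes (512-dim reference, 64 channels) are assumptions for the example, and the patch's actual AdaInConvBlock additionally wraps the input and output in ConvGnLelu/ConvGnSilu blocks.

import torch
import torch.nn as nn


class AdaInFusionSketch(nn.Module):
    # Sketch of AdaIN fusion: modulate normalized features with a style derived from a reference vector.
    def __init__(self, ref_dim, channels):
        super().__init__()
        self.to_style = nn.Linear(ref_dim, channels * 2)  # predicts per-channel scale and shift
        self.norm = nn.InstanceNorm2d(channels)

    def forward(self, x, ref):
        # x: (B, C, H, W) feature map; ref: (B, ref_dim) reference embedding.
        scale, shift = self.to_style(ref).chunk(2, dim=1)  # each (B, C)
        x = self.norm(x)
        return x * scale[:, :, None, None] + shift[:, :, None, None]


if __name__ == "__main__":
    fuse = AdaInFusionSketch(ref_dim=512, channels=64)
    out = fuse(torch.randn(2, 64, 32, 32), torch.randn(2, 512))
    print(out.shape)  # torch.Size([2, 64, 32, 32])

In the patch itself, the same reference vector is threaded through ConfigurableSwitchComputer by passing an (x, ref) tuple, whose first element is treated as the identity path while the full tuple is splatted into the pre-transform and transform blocks.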