From 0ffac391c1fcc14f9146846f65b82e9969ba8193 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 9 Sep 2020 11:17:07 -0600
Subject: [PATCH] SPSR with ref joining

---
 codes/models/archs/SPSR_arch.py | 34 +++++++++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 4 deletions(-)

diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index bb9bbbac..307d8078 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -490,6 +490,22 @@ class MultiplexerWithReducer(nn.Module):
         x = self.reduce(x)
         return self.mplex(x, ref)
 
+
+class RefJoiner(nn.Module):
+    def __init__(self, nf):
+        super(RefJoiner, self).__init__()
+        self.lin1 = nn.Linear(512, 256)
+        self.lin2 = nn.Linear(256, nf)
+        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
+
+    def forward(self, x, ref):
+        ref = self.lin1(ref)
+        ref = self.lin2(ref)
+        b, _, h, w = x.shape
+        ref = ref.view(b, -1, 1, 1)
+        return self.join(x, ref.repeat((1, 1, h, w)))
+
+
 class SwitchedSpsrWithRef2(nn.Module):
     def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
         super(SwitchedSpsrWithRef2, self).__init__()
@@ -506,14 +522,18 @@ class SwitchedSpsrWithRef2(nn.Module):
                                          transformation_filters, kernel_size=3, depth=3,
                                          weight_init_factor=.1)
 
+        self.reference_processor = ReferenceImageBranch(transformation_filters)
+
         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
         self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
+        self.ref_join1 = RefJoiner(nf)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
                                               transform_count=self.transformation_counts, init_temp=init_temperature,
                                               add_scalable_noise_to_transforms=False)
+        self.ref_join2 = RefJoiner(nf)
         self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
@@ -536,7 +556,7 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
         self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
 
-        # Join branch (grad+fea
+        # Join branch (grad+fea)
         self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
@@ -546,7 +566,7 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
         self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
         self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
-        self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
+        self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
         self.init_temperature = init_temperature
@@ -554,11 +574,16 @@ class SwitchedSpsrWithRef2(nn.Module):
 
     def forward(self, x, ref, center_coord):
         x_grad = self.get_g_nopadding(x)
+        ref = self.reference_processor(ref, center_coord)
 
         x = self.model_fea_conv(x)
         x = self.noise_ref_join(x, torch.randn_like(x))
+        x = self.ref_join1(x, ref)
         x1, a1 = self.sw1(x, True)
-        x2, a2 = self.sw2(x, True)
+
+        x2 = x1
+        x2 = self.ref_join2(x2, ref)
+        x2, a2 = self.sw2(x2, True)
 
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_lr_conv2(x_fea)
@@ -570,7 +595,8 @@ class SwitchedSpsrWithRef2(nn.Module):
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)
 
-        x_out = self.conjoin_ref_join(x_fea, x_grad)
+        x_out = x_fea
+        x_out = self.conjoin_ref_join(x_out, x_grad)
         x_out, a4 = self.conjoin_sw(x_out, True)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
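
Note on the new RefJoiner: it takes a 512-dimensional reference embedding, projects it down to nf channels through two linear layers, and broadcasts it across the spatial dimensions of the feature map before handing both tensors to ReferenceJoinBlock. The sketch below mirrors that shape flow outside the repo; SimpleResidualJoin is a hypothetical stand-in for ReferenceJoinBlock (assumed here to behave like a learned residual merge), not the actual implementation.

# Minimal sketch of the RefJoiner shape flow. SimpleResidualJoin is a
# hypothetical stand-in for ReferenceJoinBlock, NOT the repo's module.
import torch
import torch.nn as nn


class SimpleResidualJoin(nn.Module):
    """Hypothetical stand-in: adds a scaled, convolved reference onto the features."""
    def __init__(self, nf, residual_weight=0.1):
        super().__init__()
        self.conv = nn.Conv2d(nf * 2, nf, kernel_size=3, padding=1)
        self.weight = nn.Parameter(torch.tensor(residual_weight))

    def forward(self, x, ref):
        joined = self.conv(torch.cat([x, ref], dim=1))
        return x + self.weight * joined


class RefJoinerSketch(nn.Module):
    """Mirrors RefJoiner: 512-dim embedding -> nf channels -> spatial broadcast -> join."""
    def __init__(self, nf):
        super().__init__()
        self.lin1 = nn.Linear(512, 256)
        self.lin2 = nn.Linear(256, nf)
        self.join = SimpleResidualJoin(nf)

    def forward(self, x, ref):
        ref = self.lin2(self.lin1(ref))                # (b, 512) -> (b, nf)
        b, _, h, w = x.shape
        ref = ref.view(b, -1, 1, 1)                    # (b, nf, 1, 1)
        return self.join(x, ref.repeat((1, 1, h, w)))  # broadcast to (b, nf, h, w)


if __name__ == "__main__":
    nf = 64
    x = torch.randn(2, nf, 32, 32)   # LR feature map
    ref = torch.randn(2, 512)        # reference embedding (e.g. from ReferenceImageBranch)
    out = RefJoinerSketch(nf)(x, ref)
    print(out.shape)                 # torch.Size([2, 64, 32, 32])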
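
A second thing worth noting from the forward() hunk: sw1 and sw2 previously both read the same joined feature tensor, whereas after this patch sw2 consumes sw1's output (after a second reference join), so the two switches run sequentially. Below is a hypothetical stand-alone illustration of that rewiring; the placeholder modules only mimic tensor shapes and are not the repo's ConfigurableSwitchComputer or RefJoiner.

# Illustration of the rewired feature branch (shapes only).
# `Switch` and `join` are hypothetical placeholders, not the repo's modules.
import torch
import torch.nn as nn


class Switch(nn.Module):
    def __init__(self, nf):
        super().__init__()
        self.conv = nn.Conv2d(nf, nf, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


def join(x, ref):
    # Placeholder for RefJoiner: broadcast ref over space and add.
    b, c, h, w = x.shape
    return x + ref.view(b, -1, 1, 1).expand(b, c, h, w)


nf = 64
sw1, sw2 = Switch(nf), Switch(nf)
x = torch.randn(2, nf, 32, 32)   # features after model_fea_conv / noise join
ref = torch.randn(2, nf)         # reference embedding projected to nf channels

# Before this patch (parallel):  x1 = sw1(x); x2 = sw2(x)
# After this patch (sequential): sw2 sees sw1's output, re-joined with the reference.
x1 = sw1(join(x, ref))
x2 = sw2(join(x1, ref))
print(x2.shape)                  # torch.Size([2, 64, 32, 32])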