From 992b0a8d982aff4cdd64225dba908b9c46a95b24 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 10 Sep 2020 09:11:37 -0600 Subject: [PATCH] spsr3 with conjoin stage as part of the switch --- codes/models/archs/SPSR_arch.py | 43 +++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 29cae50d..4abf92c0 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -371,6 +371,17 @@ class RefJoiner(nn.Module): return self.join(x, ref.repeat((1, 1, h, w))) +class ModuleWithRef(nn.Module): + def __init__(self, nf, mcnv, *args): + super(ModuleWithRef, self).__init__() + self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False) + self.multi = mcnv(*args) + + def forward(self, x, ref): + out = self.join(x, ref) + return self.multi(out) + + class SwitchedSpsrWithRef2(nn.Module): def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): super(SwitchedSpsrWithRef2, self).__init__() @@ -426,12 +437,13 @@ class SwitchedSpsrWithRef2(nn.Module): # Join branch (grad+fea) self.ref_join4 = RefJoiner(nf) self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False) - self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False) - self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False) + self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, + functools.partial(ModuleWithRef, nf, multiplx_fn), + pre_transform_block=None, + transform_block=functools.partial(ModuleWithRef, nf, transform_fn), + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=False) self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) @@ -443,22 +455,22 @@ class SwitchedSpsrWithRef2(nn.Module): def forward(self, x, ref, center_coord): x_grad = self.get_g_nopadding(x) - #ref = self.reference_processor(ref, center_coord) + ref = self.reference_processor(ref, center_coord) x = self.model_fea_conv(x) x1 = x - #x1 = self.ref_join1(x1, ref) + x1 = self.ref_join1(x1, ref) x1, a1 = self.sw1(x1, True, identity=x) x2 = x1 - #x2 = self.noise_ref_join(x2, torch.randn_like(x2)) - #x2 = self.ref_join2(x2, ref) + x2 = self.noise_ref_join(x2, torch.randn_like(x2)) + x2 = self.ref_join2(x2, ref) x2, a2 = self.sw2(x2, True, identity=x1) x_grad = self.grad_conv(x_grad) x_grad_identity = x_grad - #x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad)) - #x_grad = self.ref_join3(x_grad, ref) + x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad)) + x_grad = self.ref_join3(x_grad, ref) x_grad = self.grad_ref_join(x_grad, x1) x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity) x_grad = self.grad_lr_conv(x_grad) @@ -467,10 +479,9 @@ class SwitchedSpsrWithRef2(nn.Module): x_grad_out = self.grad_branch_output_conv(x_grad_out) x_out = x2 - #x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out)) - #x_out = self.ref_join4(x_out, ref) - x_out = self.conjoin_ref_join(x_out, x_grad) - x_out, a4 = self.conjoin_sw(x_out, True, identity=x2) + x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out)) + x_out = self.ref_join4(x_out, ref) + x_out, a4 = self.conjoin_sw((x_out, x_grad), True, identity=x2) x_out = self.final_lr_conv(x_out) x_out = self.upsample(x_out) x_out = self.final_hr_conv1(x_out)