From 668bfbff6d34e3682020b0963954327f3695780b Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 10 Sep 2020 14:58:14 -0600 Subject: [PATCH] Back to best arch for spsr3 --- codes/models/archs/SPSR_arch.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 4abf92c0..d9c59d1f 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -437,13 +437,12 @@ class SwitchedSpsrWithRef2(nn.Module): # Join branch (grad+fea) self.ref_join4 = RefJoiner(nf) self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False) - self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, - functools.partial(ModuleWithRef, nf, multiplx_fn), - pre_transform_block=None, - transform_block=functools.partial(ModuleWithRef, nf, transform_fn), - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False) + self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False) + self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=False) self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) @@ -481,7 +480,8 @@ class SwitchedSpsrWithRef2(nn.Module): x_out = x2 x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out)) x_out = self.ref_join4(x_out, ref) - x_out, a4 = self.conjoin_sw((x_out, x_grad), True, identity=x2) + x_out = self.conjoin_ref_join(x_out, x_grad) + x_out, a4 = self.conjoin_sw(x_out, True, identity=x2) x_out = self.final_lr_conv(x_out) x_out = self.upsample(x_out) x_out = self.final_hr_conv1(x_out)