diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index af067fbb..86d0f07b 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -267,8 +267,7 @@ class SwitchedSpsr(nn.Module): # Grad branch self.get_g_nopadding = ImageGradientNoPadding() self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) - self.sw_grad_mplex_converge = ConjoinBlock2(nf) - mplex_grad = functools.partial(ConvBasisMultiplexer, nf, nf, switch_reductions, + mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions, switch_processing_layers, self.transformation_counts // 2, use_exp2=True) self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad, pre_transform_block=pretransform_fn, transform_block=transform_fn, @@ -312,8 +311,7 @@ class SwitchedSpsr(nn.Module): x_fea = self.feature_hr_conv2(x_fea) x_b_fea = self.b_fea_conv(x_grad) - grad_mplex_in = self.sw_grad_mplex_converge(x1, passthrough=x_b_fea) - x_grad, a3 = self.sw_grad(x_b_fea, att_in=grad_mplex_in, output_attention_weights=True) + x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True) x_grad = self.grad_lr_conv(x_grad) x_grad = self.grad_hr_conv(x_grad) x_out_branch = self.upsample_grad(x_grad)