Undo early dim reduction on grad branch for SPSR_arch

This commit is contained in:
James Betker 2020-08-14 16:23:42 -06:00
parent 2d205f52ac
commit 868d0aa442

View File

@ -267,8 +267,7 @@ class SwitchedSpsr(nn.Module):
# Grad branch # Grad branch
self.get_g_nopadding = ImageGradientNoPadding() self.get_g_nopadding = ImageGradientNoPadding()
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
self.sw_grad_mplex_converge = ConjoinBlock2(nf) mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
mplex_grad = functools.partial(ConvBasisMultiplexer, nf, nf, switch_reductions,
switch_processing_layers, self.transformation_counts // 2, use_exp2=True) switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad, self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
pre_transform_block=pretransform_fn, transform_block=transform_fn, pre_transform_block=pretransform_fn, transform_block=transform_fn,
@ -312,8 +311,7 @@ class SwitchedSpsr(nn.Module):
x_fea = self.feature_hr_conv2(x_fea) x_fea = self.feature_hr_conv2(x_fea)
x_b_fea = self.b_fea_conv(x_grad) x_b_fea = self.b_fea_conv(x_grad)
grad_mplex_in = self.sw_grad_mplex_converge(x1, passthrough=x_b_fea) x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True)
x_grad, a3 = self.sw_grad(x_b_fea, att_in=grad_mplex_in, output_attention_weights=True)
x_grad = self.grad_lr_conv(x_grad) x_grad = self.grad_lr_conv(x_grad)
x_grad = self.grad_hr_conv(x_grad) x_grad = self.grad_hr_conv(x_grad)
x_out_branch = self.upsample_grad(x_grad) x_out_branch = self.upsample_grad(x_grad)