Undo early dim reduction on grad branch for SPSR_arch
This commit is contained in:
parent
2d205f52ac
commit
868d0aa442
|
@ -267,8 +267,7 @@ class SwitchedSpsr(nn.Module):
|
|||
# Grad branch
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.sw_grad_mplex_converge = ConjoinBlock2(nf)
|
||||
mplex_grad = functools.partial(ConvBasisMultiplexer, nf, nf, switch_reductions,
|
||||
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
|
||||
switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
|
||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
|
@ -312,8 +311,7 @@ class SwitchedSpsr(nn.Module):
|
|||
x_fea = self.feature_hr_conv2(x_fea)
|
||||
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
grad_mplex_in = self.sw_grad_mplex_converge(x1, passthrough=x_b_fea)
|
||||
x_grad, a3 = self.sw_grad(x_b_fea, att_in=grad_mplex_in, output_attention_weights=True)
|
||||
x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True)
|
||||
x_grad = self.grad_lr_conv(x_grad)
|
||||
x_grad = self.grad_hr_conv(x_grad)
|
||||
x_out_branch = self.upsample_grad(x_grad)
|
||||
|
|
Loading…
Reference in New Issue
Block a user