Back to best arch for spsr3

This commit is contained in:
James Betker 2020-09-10 14:58:14 -06:00
parent 992b0a8d98
commit 668bfbff6d

View File

@ -437,13 +437,12 @@ class SwitchedSpsrWithRef2(nn.Module):
# Join branch (grad+fea)
self.ref_join4 = RefJoiner(nf)
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters,
functools.partial(ModuleWithRef, nf, multiplx_fn),
pre_transform_block=None,
transform_block=functools.partial(ModuleWithRef, nf, transform_fn),
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
@ -481,7 +480,8 @@ class SwitchedSpsrWithRef2(nn.Module):
x_out = x2
x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
x_out = self.ref_join4(x_out, ref)
x_out, a4 = self.conjoin_sw((x_out, x_grad), True, identity=x2)
x_out = self.conjoin_ref_join(x_out, x_grad)
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
x_out = self.final_lr_conv(x_out)
x_out = self.upsample(x_out)
x_out = self.final_hr_conv1(x_out)