Back to best arch for spsr3
This commit is contained in:
parent
992b0a8d98
commit
668bfbff6d
|
@ -437,13 +437,12 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
# Join branch (grad+fea)
|
||||
self.ref_join4 = RefJoiner(nf)
|
||||
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
|
||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters,
|
||||
functools.partial(ModuleWithRef, nf, multiplx_fn),
|
||||
pre_transform_block=None,
|
||||
transform_block=functools.partial(ModuleWithRef, nf, transform_fn),
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False)
|
||||
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True)
|
||||
|
@ -481,7 +480,8 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
x_out = x2
|
||||
x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
|
||||
x_out = self.ref_join4(x_out, ref)
|
||||
x_out, a4 = self.conjoin_sw((x_out, x_grad), True, identity=x2)
|
||||
x_out = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
x_out = self.upsample(x_out)
|
||||
x_out = self.final_hr_conv1(x_out)
|
||||
|
|
Loading…
Reference in New Issue
Block a user