Invert ResGen2 to operate in LR space
This commit is contained in:
parent
e07d8abafb
commit
87f1e9c56f
|
@ -316,7 +316,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
switches = []
|
switches = []
|
||||||
post_switch_proc = []
|
post_switch_proc = []
|
||||||
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
|
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
|
||||||
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False)
|
self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
|
||||||
|
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False)
|
||||||
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
|
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
|
||||||
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
|
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
|
||||||
switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
|
switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
|
||||||
|
@ -346,6 +347,12 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
x = x + sw_out
|
x = x + sw_out
|
||||||
x = x + conv(x)
|
x = x + conv(x)
|
||||||
|
|
||||||
|
# This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
|
||||||
|
# is called for, then repair.
|
||||||
|
if self.upsample_factor > 1:
|
||||||
|
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
|
||||||
|
|
||||||
|
x = self.proc_conv(x)
|
||||||
x = self.final_conv(x)
|
x = self.final_conv(x)
|
||||||
return x,
|
return x,
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user