forked from mrq/DL-Art-School
SRG2 architectural changes
This commit is contained in:
parent 9a1c3241f5
commit 3c31bea1ac
@@ -115,6 +115,7 @@ class ConvBasisMultiplexer(nn.Module):
         self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
         gap = self.output_filter_count - multiplexer_channels
+        # Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly.
         self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
         self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
         self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
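For reference, the three ConvBnSilu layers above taper the channel count from the processing-block output down to the multiplexer channel count in two fractional steps of the gap. A minimal sketch of that arithmetic, using hypothetical values for output_filter_count and multiplexer_channels (the real values come from create_sequential_growing_processing_block and the constructor arguments):

# Sketch only (not part of the commit): gap-based channel tapering with hypothetical counts.
output_filter_count = 96       # hypothetical value
multiplexer_channels = 8       # hypothetical value
gap = output_filter_count - multiplexer_channels     # 88
c1 = output_filter_count - (gap // 2)                 # cbl1 output: 52
c2 = output_filter_count - (3 * gap // 4)             # cbl2 output: 30
print(output_filter_count, "->", c1, "->", c2, "->", multiplexer_channels)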
@@ -152,18 +153,19 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
         switches = []
-        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
-        self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
-        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False)
+        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
+        self.sw_conv = ConvBnLelu(transformation_filters, transformation_filters, lelu=False, bias=True)
+        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
+        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
+        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
         for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
-                                                       transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
+                                                       transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters + growth, transformation_filters, kernel_size=kernel, depth=layers),
                                                        transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
-                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1))
-        # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
-        initialize_weights([s.transforms for s in switches], .2 / len(switches))
+                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.2))

         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
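Note on the init_scalar change: the explicit initialize_weights([...], .2 / len(switches)) call is removed and each switch is instead built with init_scalar=.2, so the down-weighting of the residual transforms now happens inside ConfigurableSwitchComputer. That class is not part of this diff; the sketch below only illustrates the general pattern of scaling a residual branch by a small learnable scalar, under the assumption that init_scalar works roughly this way.

import torch
import torch.nn as nn

class ScaledResidualSketch(nn.Module):
    # Illustrative stand-in, not the repo's ConfigurableSwitchComputer.
    def __init__(self, block, init_scalar=.2):
        super().__init__()
        self.block = block
        # Start the residual contribution small so the trunk dominates early training.
        self.scale = nn.Parameter(torch.tensor(init_scalar))

    def forward(self, x):
        return x + self.scale * self.block(x)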
@@ -178,16 +180,18 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         x = self.initial_conv(x)

         self.attentions = []
+        swx = x
         for i, sw in enumerate(self.switches):
-            x, att = sw.forward(x, True)
+            swx, att = sw.forward(swx, True)
             self.attentions.append(att)
+        x = swx + self.sw_conv(x)

-        if self.upsample_factor > 1:
-            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
-        x = self.proc_conv(x)
-        x = self.final_conv(x)
-        return x,
+        assert self.upsample_factor == 2 or self.upsample_factor == 4
+        x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
+        if self.upsample_factor > 2:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        x = self.upconv2(x)
+        return self.final_conv(self.hr_conv(x)),

     def set_temperature(self, temp):
         [sw.set_temperature(temp) for sw in self.switches]
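The forward pass now keeps the switch output on a separate swx branch, folds it back into the trunk through sw_conv, and replaces the single proc_conv/final_conv tail with two x2 nearest-neighbor upsampling steps interleaved with upconv1/upconv2 and finished by hr_conv and final_conv. A self-contained sketch of that tail, with the conv layers passed in as plain arguments (assumed handles, not the repo's module):

import torch.nn.functional as F

def upsample_tail_sketch(x, upconv1, upconv2, hr_conv, final_conv, upsample_factor=4):
    # Sketch of the new upsampling tail; the factor must be 2 or 4.
    assert upsample_factor == 2 or upsample_factor == 4
    x = upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
    if upsample_factor > 2:
        # Second x2 step only for 4x upsampling.
        x = F.interpolate(x, scale_factor=2, mode="nearest")
    x = upconv2(x)
    return final_conv(hr_conv(x))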