forked from mrq/DL-Art-School
More SRG2 adjustments..
This commit is contained in:
parent 086b2f0570
commit 0acad81035
@@ -54,6 +54,46 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
     return nn.Sequential(*convs), current_filters


+class ConvBasisMultiplexer(nn.Module):
+    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
+        super(ConvBasisMultiplexer, self).__init__()
+        self.filter_conv = ConvBnSilu(input_channels, base_filters, bias=True)
+        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
+        reduction_filters = base_filters * 2 ** reductions
+        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
+
+        gap = self.output_filter_count - multiplexer_channels
+        # Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly.
+        self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
+        self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
+        self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
+
+    def forward(self, x):
+        x = self.filter_conv(x)
+        x = self.reduction_blocks(x)
+        x = self.processing_blocks(x)
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+
+
+class SpineNetMultiplexer(nn.Module):
+    def __init__(self, input_channels, transform_count):
+        super(SpineNetMultiplexer, self).__init__()
+        self.backbone = SpineNet('49', in_channels=input_channels)
+        self.rdc1 = ConvBnSilu(256, 128, kernel_size=3, bias=False)
+        self.rdc2 = ConvBnSilu(128, 64, kernel_size=3, bias=False)
+        self.rdc3 = ConvBnSilu(64, transform_count, bias=False, bn=False, relu=False)
+
+    def forward(self, x):
+        spine = self.backbone(x)
+        feat = self.rdc1(spine[0])
+        feat = self.rdc2(feat)
+        feat = self.rdc3(feat)
+        return feat
+
+
 class ConfigurableSwitchComputer(nn.Module):
     def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
                  enable_negative_transforms=False, add_scalable_noise_to_transforms=False, init_scalar=1):
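In the new ConvBasisMultiplexer above, the three ConvBnSilu layers close the gap between the processing-block output width and the requested multiplexer width in stages rather than in one jump. A minimal arithmetic sketch of that taper, assuming illustrative values of 128 output filters and 8 multiplexer channels (the real values come from the constructor arguments):

# Illustrative only: the cbl1/cbl2/cbl3 channel taper for one assumed configuration.
output_filter_count = 128
multiplexer_channels = 8
gap = output_filter_count - multiplexer_channels   # 120

c1_out = output_filter_count - (gap // 2)          # 128 - 60 = 68
c2_out = output_filter_count - (3 * gap // 4)      # 128 - 90 = 38
c3_out = multiplexer_channels                      # 8

print(c1_out, c2_out, c3_out)  # 68 38 8 -- channels shrink stepwise toward the multiplexer width

Splitting the reduction across three convolutions leaves the network a few layers in which to adapt the features, which is what the in-code comment about interpolating earlier is pointing at.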
@@ -106,46 +146,6 @@ class ConfigurableSwitchComputer(nn.Module):
         self.switch.set_attention_temperature(temp)


-class ConvBasisMultiplexer(nn.Module):
-    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
-        super(ConvBasisMultiplexer, self).__init__()
-        self.filter_conv = ConvBnSilu(input_channels, base_filters, bias=True)
-        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
-        reduction_filters = base_filters * 2 ** reductions
-        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
-
-        gap = self.output_filter_count - multiplexer_channels
-        # Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly.
-        self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
-        self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
-        self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
-
-    def forward(self, x):
-        x = self.filter_conv(x)
-        x = self.reduction_blocks(x)
-        x = self.processing_blocks(x)
-        x = self.cbl1(x)
-        x = self.cbl2(x)
-        x = self.cbl3(x)
-        return x
-
-
-class SpineNetMultiplexer(nn.Module):
-    def __init__(self, input_channels, transform_count):
-        super(SpineNetMultiplexer, self).__init__()
-        self.backbone = SpineNet('49', in_channels=input_channels)
-        self.rdc1 = ConvBnSilu(256, 128, kernel_size=3, bias=False)
-        self.rdc2 = ConvBnSilu(128, 64, kernel_size=3, bias=False)
-        self.rdc3 = ConvBnSilu(64, transform_count, bias=False, bn=False, relu=False)
-
-    def forward(self, x):
-        spine = self.backbone(x)
-        feat = self.rdc1(spine[0])
-        feat = self.rdc2(feat)
-        feat = self.rdc3(feat)
-        return feat
-
-
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
@@ -165,7 +165,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
                 pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
                 transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters + growth, transformation_filters, kernel_size=kernel, depth=layers),
                 transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
-                add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.2))
+                add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.1))

         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
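The hunk above only lowers init_scalar from .2 to .1. The surrounding pre_transform_block and transform_block arguments are functools.partial factories, so the switch computer can call them repeatedly to build one block per transform. A minimal sketch of that factory pattern, using a hypothetical stand-in for MultiConvBlock and made-up parameter values; only the shape of the partial mirrors the diff:

import functools

# Hypothetical stand-in for MultiConvBlock, just to show the factory pattern;
# the real block lives elsewhere in the repo.
class MultiConvBlock:
    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth):
        self.shape = (filters_in, filters_mid, filters_out, kernel_size, depth)

transformation_filters, growth, kernel, layers, trans_count = 64, 32, 3, 4, 8

# As in the diff: the partial acts as a zero-argument factory that can be called
# trans_count times to build identically configured transform blocks.
transform_block = functools.partial(MultiConvBlock, transformation_filters,
                                    transformation_filters + growth,
                                    transformation_filters,
                                    kernel_size=kernel, depth=layers)
blocks = [transform_block() for _ in range(trans_count)]
print(len(blocks), blocks[0].shape)  # 8 (64, 96, 64, 3, 4)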
@@ -181,11 +181,9 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         x = self.initial_conv(x)

         self.attentions = []
-        swx = x
         for i, sw in enumerate(self.switches):
-            swx, att = sw.forward(swx, True)
+            x, att = sw.forward(x, True)
             self.attentions.append(att)
-        x = swx + self.sw_conv(x)

         x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
         if self.upsample_factor > 2:
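The final hunk simplifies the generator's forward pass: previously the switches chained through a separate swx variable and the result was mixed with a self.sw_conv residual computed from the pre-switch features; now x is chained through the switches directly and that residual path is dropped. A minimal side-by-side sketch of the two control flows, lifted out of the class for clarity; function names and signatures are illustrative, and only the loop structure comes from the diff:

def forward_old(x, switches, sw_conv):
    # Old behaviour: chain through swx, then add a sw_conv residual of the
    # original (pre-switch) features.
    attentions = []
    swx = x
    for sw in switches:
        swx, att = sw.forward(swx, True)
        attentions.append(att)
    return swx + sw_conv(x), attentions

def forward_new(x, switches):
    # New behaviour: chain x through the switches directly; no extra residual.
    attentions = []
    for sw in switches:
        x, att = sw.forward(x, True)
        attentions.append(att)
    return x, attentions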