Move ExpansionBlock to arch_util
Also makes all processing blocks conform to a single signature. Alters ExpansionBlock to perform a processing conv on the passthrough before the conjoin operation; this breaks backwards compatibility with SRG2.
parent 5e8b52f34c
commit 33ca3832e1
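
Usage note (not part of the diff): a minimal sketch of how the relocated ExpansionBlock is called after this change. The constructor and forward() signature, the ConvGnSilu default for the block argument, and the tensor shapes follow the code and comments added to arch_util in the hunks below; the concrete batch/filter/spatial sizes are illustrative assumptions only.

import torch
from models.archs.arch_util import ExpansionBlock

block = ExpansionBlock(64)                  # decimates 64 -> 32 filters while upsampling 2x
feature = torch.randn(1, 64, 16, 16)        # (b, f, w, h) feature signal
passthrough = torch.randn(1, 32, 32, 32)    # (b, f/2, w*2, h*2) structure signal
out = block(feature, passthrough)           # (1, 32, 32, 32); the passthrough now gets its own conv before the conjoin
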
@@ -83,9 +83,9 @@ class Constrictor(nn.Module):
         assert(filters > output_filters)
         gap = filters - output_filters
         gap_div_4 = int(gap / 4)
-        self.cbl1 = ConvBnRelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True, bias=True)
-        self.cbl2 = ConvBnRelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True, bias=False)
-        self.cbl3 = ConvBnRelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, relu=False, bn=False, bias=False)
+        self.cbl1 = ConvBnRelu(filters, filters - (gap_div_4 * 2), kernel_size=1, norm=True, bias=True)
+        self.cbl2 = ConvBnRelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, norm=True, bias=False)
+        self.cbl3 = ConvBnRelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, activation=False, norm=False, bias=False)

     def forward(self, x):
         x = self.cbl1(x)

@@ -134,10 +134,10 @@ class NestedSwitchComputer(nn.Module):
             filters.append(current_filters)
             reduce = True

-        self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False)
+        self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, activation=False, norm=False)
         self.processing_trunk = nn.ModuleList(processing_trunk)
         self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
-        self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False)
+        self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, norm=False)

     def forward(self, x):
         feed_forward = x

@@ -161,9 +161,9 @@ class NestedSwitchedGenerator(nn.Module):
                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
         super(NestedSwitchedGenerator, self).__init__()
-        self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, lelu=False, bn=False)
-        self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
-        self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, lelu=False, bn=False)
+        self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, activation=False, norm=False)
+        self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, activation=False, norm=False)

         switches = []
         for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):

@@ -12,9 +12,9 @@ class MultiConvBlock(nn.Module):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False) for i in range(depth-2)] +
-                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False)])
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=bn, bias=False)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=bn, bias=False) for i in range(depth - 2)] +
+                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))


@@ -32,8 +32,8 @@ class MultiConvBlock(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, bn=False, bias=False)
-        self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, bn=True, bias=False)
+        self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, norm=False, bias=False)
+        self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, norm=True, bias=False)
     def forward(self, x):
         x = self.bnconv1(x)
         return self.bnconv2(x)

@@ -45,7 +45,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
     convs = []
     current_filters = filters_init
     for i in range(num_convs):
-        convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False))
+        convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, norm=True, bias=False))
         current_filters += filter_growth
     return nn.Sequential(*convs), current_filters


@@ -60,7 +60,7 @@ class SwitchComputer(nn.Module):
         self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
         final_filters = filters * 2 ** reduction_blocks
         self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
-        self.post_interpolate_decimate = ConvBnSilu(final_filters, filters, kernel_size=1, silu=False, bn=False)
+        self.post_interpolate_decimate = ConvBnSilu(final_filters, filters, kernel_size=1, activation=False, norm=False)
         self.interpolate_process = ConvBnSilu(filters, filters)
         self.interpolate_process2 = ConvBnSilu(filters, filters)
         tc = transform_count

@@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock
 from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
 from models.archs.spinenet_arch import SpineNet
 from switched_conv_util import save_attention_to_image

@@ -15,9 +15,9 @@ class MultiConvBlock(nn.Module):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth-2)] +
-                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False, weight_init_factor=weight_init_factor)])
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
+                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))


@@ -35,28 +35,14 @@ class MultiConvBlock(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, gn=False, bias=False)
-        self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, gn=True, bias=False)
+        self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, norm=False, bias=False)
+        self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, norm=True, bias=False)

     def forward(self, x):
         x = self.bnconv1(x)
         return self.bnconv2(x)


-class ExpansionBlock(nn.Module):
-    def __init__(self, filters):
-        super(ExpansionBlock, self).__init__()
-        self.decimate = ConvGnSilu(filters, filters // 2, kernel_size=1, bias=False, silu=False, gn=False)
-        self.conjoin = ConvGnSilu(filters, filters // 2, kernel_size=3, bias=True, silu=False, gn=True)
-        self.process = ConvGnSilu(filters // 2, filters // 2, kernel_size=3, bias=False, silu=True, gn=True)
-
-    def forward(self, input, passthrough):
-        x = F.interpolate(input, scale_factor=2, mode="nearest")
-        x = self.decimate(x)
-        x = self.conjoin(torch.cat([x, passthrough], dim=1))
-        return self.process(x)
-
-
 # This is a classic u-net architecture with the goal of assigning each individual pixel an individual transform
 # switching set.
 class ConvBasisMultiplexer(nn.Module):

@@ -70,10 +56,10 @@ class ConvBasisMultiplexer(nn.Module):

         gap = base_filters - multiplexer_channels
         cbl1_out = ((base_filters - (gap // 2)) // 4) * 4 # Must be multiples of 4 to use with group norm.
-        self.cbl1 = ConvGnSilu(base_filters, cbl1_out, gn=use_gn, bias=False, num_groups=4)
+        self.cbl1 = ConvGnSilu(base_filters, cbl1_out, norm=use_gn, bias=False, num_groups=4)
         cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4
-        self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, gn=use_gn, bias=False, num_groups=4)
-        self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, gn=False)
+        self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, norm=use_gn, bias=False, num_groups=4)
+        self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False)

     def forward(self, x):
         x = self.filter_conv(x)

@@ -109,11 +95,11 @@ class BackboneMultiplexer(nn.Module):
         self.backbone = backbone
         self.proc = nn.Sequential(ConvGnSilu(256, 256, kernel_size=3, bias=True),
                                   ConvGnSilu(256, 256, kernel_size=3, bias=False))
-        self.up1 = nn.Sequential(ConvGnSilu(256, 128, kernel_size=3, bias=False, gn=False, silu=False),
+        self.up1 = nn.Sequential(ConvGnSilu(256, 128, kernel_size=3, bias=False, norm=False, activation=False),
                                  ConvGnSilu(128, 128, kernel_size=3, bias=False))
-        self.up2 = nn.Sequential(ConvGnSilu(128, 64, kernel_size=3, bias=False, gn=False, silu=False),
+        self.up2 = nn.Sequential(ConvGnSilu(128, 64, kernel_size=3, bias=False, norm=False, activation=False),
                                  ConvGnSilu(64, 64, kernel_size=3, bias=False))
-        self.final = ConvGnSilu(64, transform_count, bias=False, gn=False, silu=False)
+        self.final = ConvGnSilu(64, transform_count, bias=False, norm=False, activation=False)

     def forward(self, x):
         spine = self.backbone.get_forward_result()

@@ -139,7 +125,7 @@ class ConfigurableSwitchComputer(nn.Module):
         # And the switch itself, including learned scalars
         self.switch = BareConvSwitch(initial_temperature=init_temp)
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
-        self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
+        self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))

@@ -174,11 +160,11 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
         switches = []
-        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
-        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
-        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
-        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
-        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
+        self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
+        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
         for _ in range(switch_depth):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
             pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)

@@ -258,19 +244,19 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
                  heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=4):
         super(ConfigurableSwitchedResidualGenerator3, self).__init__()
-        self.initial_conv = ConvBnLelu(3, base_filters, bn=False, lelu=False, bias=True)
-        self.sw_conv = ConvBnLelu(base_filters, base_filters, lelu=False, bias=True)
-        self.upconv1 = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
-        self.upconv2 = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
-        self.hr_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
-        self.final_conv = ConvBnLelu(base_filters, 3, bn=False, lelu=False, bias=True)
+        self.initial_conv = ConvBnLelu(3, base_filters, norm=False, activation=False, bias=True)
+        self.sw_conv = ConvBnLelu(base_filters, base_filters, activation=False, bias=True)
+        self.upconv1 = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
+        self.upconv2 = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
+        self.hr_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
+        self.final_conv = ConvBnLelu(base_filters, 3, norm=False, activation=False, bias=True)

         self.backbone = SpineNet('49', in_channels=3, use_input_norm=True)
         for p in self.backbone.parameters(recurse=True):
             p.requires_grad = False
         self.backbone_wrapper = CachedBackboneWrapper(self.backbone)
         multiplx_fn = functools.partial(BackboneMultiplexer, self.backbone_wrapper)
-        pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, bn=False, lelu=False, bias=False))
+        pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, norm=False, activation=False, bias=False))
         transform_fn = functools.partial(MultiConvBlock, base_filters, int(base_filters * 1.5), base_filters, kernel_size=3, depth=4)
         self.switch = ConfigurableSwitchComputer(base_filters, multiplx_fn, pretransform_fn, transform_fn, trans_count, init_temp=initial_temp,
                                                  add_scalable_noise_to_transforms=True, init_scalar=.1)

@@ -184,16 +184,16 @@ class SiLU(nn.Module):
 ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnRelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True):
         super(ConvBnRelu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
         self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
+        if norm:
             self.bn = nn.BatchNorm2d(filters_out)
         else:
             self.bn = None
-        if relu:
+        if activation:
             self.relu = nn.ReLU()
         else:
             self.relu = None

@@ -219,16 +219,16 @@ class ConvBnRelu(nn.Module):
 ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnSilu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True, weight_init_factor=1):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
         super(ConvBnSilu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
         self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
+        if norm:
             self.bn = nn.BatchNorm2d(filters_out)
         else:
             self.bn = None
-        if silu:
+        if activation:
             self.silu = SiLU()
         else:
             self.silu = None

@@ -257,16 +257,16 @@ class ConvBnSilu(nn.Module):
 ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnLelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True, weight_init_factor=1):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
         super(ConvBnLelu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
         self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
+        if norm:
             self.bn = nn.BatchNorm2d(filters_out)
         else:
             self.bn = None
-        if lelu:
+        if activation:
             self.lelu = nn.LeakyReLU(negative_slope=.1)
         else:
             self.lelu = None

@@ -296,16 +296,16 @@ class ConvBnLelu(nn.Module):
 ''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvGnLelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, gn=True, bias=True, num_groups=8):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8):
         super(ConvGnLelu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
         self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if gn:
+        if norm:
             self.gn = nn.GroupNorm(num_groups, filters_out)
         else:
             self.gn = None
-        if lelu:
+        if activation:
             self.lelu = nn.LeakyReLU(negative_slope=.1)
         else:
             self.lelu = None

@@ -331,16 +331,16 @@ class ConvGnLelu(nn.Module):
 ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvGnSilu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, gn=True, bias=True, num_groups=8, weight_init_factor=1):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
         super(ConvGnSilu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
         self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if gn:
+        if norm:
             self.gn = nn.GroupNorm(num_groups, filters_out)
         else:
             self.gn = None
-        if silu:
+        if activation:
             self.silu = SiLU()
         else:
             self.silu = None

@@ -363,4 +363,24 @@ class ConvGnSilu(nn.Module):
         if self.silu:
             return self.silu(x)
         else:
             return x
+
+# Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed
+# along with the feature representation.
+class ExpansionBlock(nn.Module):
+    def __init__(self, filters, block=ConvGnSilu):
+        super(ExpansionBlock, self).__init__()
+        self.decimate = block(filters, filters // 2, kernel_size=1, bias=False, activation=False, norm=True)
+        self.process_passthrough = block(filters // 2, filters // 2, kernel_size=3, bias=True, activation=False, norm=True)
+        self.conjoin = block(filters, filters // 2, kernel_size=3, bias=False, activation=True, norm=False)
+        self.process = block(filters // 2, filters // 2, kernel_size=3, bias=False, activation=True, norm=True)
+
+    # input is the feature signal with shape (b, f, w, h)
+    # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
+    # output is conjoined upsample with shape (b, f/2, w*2, h*2)
+    def forward(self, input, passthrough):
+        x = F.interpolate(input, scale_factor=2, mode="nearest")
+        x = self.decimate(x)
+        p = self.process_passthrough(passthrough)
+        x = self.conjoin(torch.cat([x, p], dim=1))
+        return self.process(x)

@@ -108,20 +108,20 @@ class Discriminator_VGG_PixLoss(nn.Module):
         self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)

         self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False)
-        self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, gn=False, lelu=False)
+        self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)

         # Pyramid network: upsample with residuals and produce losses at multiple resolutions.
-        self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, lelu=False)
+        self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, activation=False)
         self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False)
         self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False)
         self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False)
-        self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, gn=False, lelu=False)
+        self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)

-        self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, lelu=False)
+        self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, activation=False)
         self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False)
         self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False)
         self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False)
-        self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, gn=False, lelu=False)
+        self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, norm=False, activation=False)

         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
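
The "conformant signature" mentioned in the commit message is visible in the arch_util hunks above: every Conv* convenience block now takes the same activation= and norm= keywords in place of the per-class relu/silu/lelu and bn/gn flags, which is what lets ExpansionBlock accept any of them through its block= argument. A short sketch of the unified keywords; the filter counts here are placeholder assumptions:

from models.archs.arch_util import ConvBnRelu, ConvBnLelu, ConvGnSilu

proj = ConvBnRelu(64, 32, kernel_size=1, activation=False, norm=False, bias=False)  # bare 1x1 projection
post = ConvBnLelu(64, 64, norm=False, bias=True)                                    # conv + LeakyReLU, no BatchNorm
feat = ConvGnSilu(64, 32, kernel_size=3, bias=False, activation=True, norm=True)    # conv + GroupNorm + SiLU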