Move ExpansionBlock to arch_util

Also makes all processing blocks have a conformant signature.

Alters ExpansionBlock to perform a processing conv on the passthrough
before the conjoin operation - this will break backwards compatibilty with SRG2.
This commit is contained in:
James Betker 2020-07-10 15:53:41 -06:00
parent 5e8b52f34c
commit 33ca3832e1
5 changed files with 81 additions and 75 deletions

View File

@ -83,9 +83,9 @@ class Constrictor(nn.Module):
assert(filters > output_filters) assert(filters > output_filters)
gap = filters - output_filters gap = filters - output_filters
gap_div_4 = int(gap / 4) gap_div_4 = int(gap / 4)
self.cbl1 = ConvBnRelu(filters, filters - (gap_div_4 * 2), kernel_size=1, bn=True, bias=True) self.cbl1 = ConvBnRelu(filters, filters - (gap_div_4 * 2), kernel_size=1, norm=True, bias=True)
self.cbl2 = ConvBnRelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, bn=True, bias=False) self.cbl2 = ConvBnRelu(filters - (gap_div_4 * 2), filters - (gap_div_4 * 3), kernel_size=1, norm=True, bias=False)
self.cbl3 = ConvBnRelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, relu=False, bn=False, bias=False) self.cbl3 = ConvBnRelu(filters - (gap_div_4 * 3), output_filters, kernel_size=1, activation=False, norm=False, bias=False)
def forward(self, x): def forward(self, x):
x = self.cbl1(x) x = self.cbl1(x)
@ -134,10 +134,10 @@ class NestedSwitchComputer(nn.Module):
filters.append(current_filters) filters.append(current_filters)
reduce = True reduce = True
self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, lelu=False, bn=False) self.multiplexer_init_conv = ConvBnLelu(transform_filters, switch_base_filters, kernel_size=7, activation=False, norm=False)
self.processing_trunk = nn.ModuleList(processing_trunk) self.processing_trunk = nn.ModuleList(processing_trunk)
self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms) self.switch = RecursiveSwitchedTransform(transform_filters, filters, nesting_depth-1, transforms_at_leaf, trans_kernel_size, trans_num_layers-1, trans_scale_init, initial_temp=initial_temp, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)
self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, bn=False) self.anneal = ConvBnLelu(transform_filters, transform_filters, kernel_size=1, norm=False)
def forward(self, x): def forward(self, x):
feed_forward = x feed_forward = x
@ -161,9 +161,9 @@ class NestedSwitchedGenerator(nn.Module):
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False): heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
super(NestedSwitchedGenerator, self).__init__() super(NestedSwitchedGenerator, self).__init__()
self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, lelu=False, bn=False) self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, activation=False, norm=False)
self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False) self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False)
self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, lelu=False, bn=False) self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, activation=False, norm=False)
switches = [] switches = []
for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers): for sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):

View File

@ -12,9 +12,9 @@ class MultiConvBlock(nn.Module):
assert depth >= 2 assert depth >= 2
super(MultiConvBlock, self).__init__() super(MultiConvBlock, self).__init__()
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01)) self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False)] + self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=bn, bias=False)] +
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False) for i in range(depth-2)] + [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=bn, bias=False) for i in range(depth - 2)] +
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False)]) [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False)])
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init)) self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
self.bias = nn.Parameter(torch.zeros(1)) self.bias = nn.Parameter(torch.zeros(1))
@ -32,8 +32,8 @@ class MultiConvBlock(nn.Module):
class HalvingProcessingBlock(nn.Module): class HalvingProcessingBlock(nn.Module):
def __init__(self, filters): def __init__(self, filters):
super(HalvingProcessingBlock, self).__init__() super(HalvingProcessingBlock, self).__init__()
self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, bn=False, bias=False) self.bnconv1 = ConvBnSilu(filters, filters * 2, stride=2, norm=False, bias=False)
self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, bn=True, bias=False) self.bnconv2 = ConvBnSilu(filters * 2, filters * 2, norm=True, bias=False)
def forward(self, x): def forward(self, x):
x = self.bnconv1(x) x = self.bnconv1(x)
return self.bnconv2(x) return self.bnconv2(x)
@ -45,7 +45,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
convs = [] convs = []
current_filters = filters_init current_filters = filters_init
for i in range(num_convs): for i in range(num_convs):
convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False)) convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, norm=True, bias=False))
current_filters += filter_growth current_filters += filter_growth
return nn.Sequential(*convs), current_filters return nn.Sequential(*convs), current_filters
@ -60,7 +60,7 @@ class SwitchComputer(nn.Module):
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)]) self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
final_filters = filters * 2 ** reduction_blocks final_filters = filters * 2 ** reduction_blocks
self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks) self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
self.post_interpolate_decimate = ConvBnSilu(final_filters, filters, kernel_size=1, silu=False, bn=False) self.post_interpolate_decimate = ConvBnSilu(final_filters, filters, kernel_size=1, activation=False, norm=False)
self.interpolate_process = ConvBnSilu(filters, filters) self.interpolate_process = ConvBnSilu(filters, filters)
self.interpolate_process2 = ConvBnSilu(filters, filters) self.interpolate_process2 = ConvBnSilu(filters, filters)
tc = transform_count tc = transform_count

View File

@ -4,7 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
import torch.nn.functional as F import torch.nn.functional as F
import functools import functools
from collections import OrderedDict from collections import OrderedDict
from models.archs.arch_util import ConvBnLelu, ConvGnSilu from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
from models.archs.spinenet_arch import SpineNet from models.archs.spinenet_arch import SpineNet
from switched_conv_util import save_attention_to_image from switched_conv_util import save_attention_to_image
@ -15,9 +15,9 @@ class MultiConvBlock(nn.Module):
assert depth >= 2 assert depth >= 2
super(MultiConvBlock, self).__init__() super(MultiConvBlock, self).__init__()
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01)) self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor)] + self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor)] +
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth-2)] + [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False, weight_init_factor=weight_init_factor)]) [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init)) self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
self.bias = nn.Parameter(torch.zeros(1)) self.bias = nn.Parameter(torch.zeros(1))
@ -35,28 +35,14 @@ class MultiConvBlock(nn.Module):
class HalvingProcessingBlock(nn.Module): class HalvingProcessingBlock(nn.Module):
def __init__(self, filters): def __init__(self, filters):
super(HalvingProcessingBlock, self).__init__() super(HalvingProcessingBlock, self).__init__()
self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, gn=False, bias=False) self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, norm=False, bias=False)
self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, gn=True, bias=False) self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, norm=True, bias=False)
def forward(self, x): def forward(self, x):
x = self.bnconv1(x) x = self.bnconv1(x)
return self.bnconv2(x) return self.bnconv2(x)
class ExpansionBlock(nn.Module):
def __init__(self, filters):
super(ExpansionBlock, self).__init__()
self.decimate = ConvGnSilu(filters, filters // 2, kernel_size=1, bias=False, silu=False, gn=False)
self.conjoin = ConvGnSilu(filters, filters // 2, kernel_size=3, bias=True, silu=False, gn=True)
self.process = ConvGnSilu(filters // 2, filters // 2, kernel_size=3, bias=False, silu=True, gn=True)
def forward(self, input, passthrough):
x = F.interpolate(input, scale_factor=2, mode="nearest")
x = self.decimate(x)
x = self.conjoin(torch.cat([x, passthrough], dim=1))
return self.process(x)
# This is a classic u-net architecture with the goal of assigning each individual pixel an individual transform # This is a classic u-net architecture with the goal of assigning each individual pixel an individual transform
# switching set. # switching set.
class ConvBasisMultiplexer(nn.Module): class ConvBasisMultiplexer(nn.Module):
@ -70,10 +56,10 @@ class ConvBasisMultiplexer(nn.Module):
gap = base_filters - multiplexer_channels gap = base_filters - multiplexer_channels
cbl1_out = ((base_filters - (gap // 2)) // 4) * 4 # Must be multiples of 4 to use with group norm. cbl1_out = ((base_filters - (gap // 2)) // 4) * 4 # Must be multiples of 4 to use with group norm.
self.cbl1 = ConvGnSilu(base_filters, cbl1_out, gn=use_gn, bias=False, num_groups=4) self.cbl1 = ConvGnSilu(base_filters, cbl1_out, norm=use_gn, bias=False, num_groups=4)
cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4 cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4
self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, gn=use_gn, bias=False, num_groups=4) self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, norm=use_gn, bias=False, num_groups=4)
self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, gn=False) self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False)
def forward(self, x): def forward(self, x):
x = self.filter_conv(x) x = self.filter_conv(x)
@ -109,11 +95,11 @@ class BackboneMultiplexer(nn.Module):
self.backbone = backbone self.backbone = backbone
self.proc = nn.Sequential(ConvGnSilu(256, 256, kernel_size=3, bias=True), self.proc = nn.Sequential(ConvGnSilu(256, 256, kernel_size=3, bias=True),
ConvGnSilu(256, 256, kernel_size=3, bias=False)) ConvGnSilu(256, 256, kernel_size=3, bias=False))
self.up1 = nn.Sequential(ConvGnSilu(256, 128, kernel_size=3, bias=False, gn=False, silu=False), self.up1 = nn.Sequential(ConvGnSilu(256, 128, kernel_size=3, bias=False, norm=False, activation=False),
ConvGnSilu(128, 128, kernel_size=3, bias=False)) ConvGnSilu(128, 128, kernel_size=3, bias=False))
self.up2 = nn.Sequential(ConvGnSilu(128, 64, kernel_size=3, bias=False, gn=False, silu=False), self.up2 = nn.Sequential(ConvGnSilu(128, 64, kernel_size=3, bias=False, norm=False, activation=False),
ConvGnSilu(64, 64, kernel_size=3, bias=False)) ConvGnSilu(64, 64, kernel_size=3, bias=False))
self.final = ConvGnSilu(64, transform_count, bias=False, gn=False, silu=False) self.final = ConvGnSilu(64, transform_count, bias=False, norm=False, activation=False)
def forward(self, x): def forward(self, x):
spine = self.backbone.get_forward_result() spine = self.backbone.get_forward_result()
@ -139,7 +125,7 @@ class ConfigurableSwitchComputer(nn.Module):
# And the switch itself, including learned scalars # And the switch itself, including learned scalars
self.switch = BareConvSwitch(initial_temperature=init_temp) self.switch = BareConvSwitch(initial_temperature=init_temp)
self.switch_scale = nn.Parameter(torch.full((1,), float(1))) self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True) self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
# The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not) # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
# depending on its needs. # depending on its needs.
self.psc_scale = nn.Parameter(torch.full((1,), float(.1))) self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
@ -174,11 +160,11 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
add_scalable_noise_to_transforms=False): add_scalable_noise_to_transforms=False):
super(ConfigurableSwitchedResidualGenerator2, self).__init__() super(ConfigurableSwitchedResidualGenerator2, self).__init__()
switches = [] switches = []
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True) self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True) self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True) self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True) self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True) self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
for _ in range(switch_depth): for _ in range(switch_depth):
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts) multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1) pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)
@ -258,19 +244,19 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
heightened_temp_min=1, heightened_temp_min=1,
heightened_final_step=50000, upsample_factor=4): heightened_final_step=50000, upsample_factor=4):
super(ConfigurableSwitchedResidualGenerator3, self).__init__() super(ConfigurableSwitchedResidualGenerator3, self).__init__()
self.initial_conv = ConvBnLelu(3, base_filters, bn=False, lelu=False, bias=True) self.initial_conv = ConvBnLelu(3, base_filters, norm=False, activation=False, bias=True)
self.sw_conv = ConvBnLelu(base_filters, base_filters, lelu=False, bias=True) self.sw_conv = ConvBnLelu(base_filters, base_filters, activation=False, bias=True)
self.upconv1 = ConvBnLelu(base_filters, base_filters, bn=False, bias=True) self.upconv1 = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
self.upconv2 = ConvBnLelu(base_filters, base_filters, bn=False, bias=True) self.upconv2 = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
self.hr_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True) self.hr_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
self.final_conv = ConvBnLelu(base_filters, 3, bn=False, lelu=False, bias=True) self.final_conv = ConvBnLelu(base_filters, 3, norm=False, activation=False, bias=True)
self.backbone = SpineNet('49', in_channels=3, use_input_norm=True) self.backbone = SpineNet('49', in_channels=3, use_input_norm=True)
for p in self.backbone.parameters(recurse=True): for p in self.backbone.parameters(recurse=True):
p.requires_grad = False p.requires_grad = False
self.backbone_wrapper = CachedBackboneWrapper(self.backbone) self.backbone_wrapper = CachedBackboneWrapper(self.backbone)
multiplx_fn = functools.partial(BackboneMultiplexer, self.backbone_wrapper) multiplx_fn = functools.partial(BackboneMultiplexer, self.backbone_wrapper)
pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, bn=False, lelu=False, bias=False)) pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, norm=False, activation=False, bias=False))
transform_fn = functools.partial(MultiConvBlock, base_filters, int(base_filters * 1.5), base_filters, kernel_size=3, depth=4) transform_fn = functools.partial(MultiConvBlock, base_filters, int(base_filters * 1.5), base_filters, kernel_size=3, depth=4)
self.switch = ConfigurableSwitchComputer(base_filters, multiplx_fn, pretransform_fn, transform_fn, trans_count, init_temp=initial_temp, self.switch = ConfigurableSwitchComputer(base_filters, multiplx_fn, pretransform_fn, transform_fn, trans_count, init_temp=initial_temp,
add_scalable_noise_to_transforms=True, init_scalar=.1) add_scalable_noise_to_transforms=True, init_scalar=.1)

View File

@ -184,16 +184,16 @@ class SiLU(nn.Module):
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
kernel sizes. ''' kernel sizes. '''
class ConvBnRelu(nn.Module): class ConvBnRelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True):
super(ConvBnRelu, self).__init__() super(ConvBnRelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3} padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys() assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if bn: if norm:
self.bn = nn.BatchNorm2d(filters_out) self.bn = nn.BatchNorm2d(filters_out)
else: else:
self.bn = None self.bn = None
if relu: if activation:
self.relu = nn.ReLU() self.relu = nn.ReLU()
else: else:
self.relu = None self.relu = None
@ -219,16 +219,16 @@ class ConvBnRelu(nn.Module):
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes. ''' kernel sizes. '''
class ConvBnSilu(nn.Module): class ConvBnSilu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True, weight_init_factor=1): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
super(ConvBnSilu, self).__init__() super(ConvBnSilu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3} padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys() assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if bn: if norm:
self.bn = nn.BatchNorm2d(filters_out) self.bn = nn.BatchNorm2d(filters_out)
else: else:
self.bn = None self.bn = None
if silu: if activation:
self.silu = SiLU() self.silu = SiLU()
else: else:
self.silu = None self.silu = None
@ -257,16 +257,16 @@ class ConvBnSilu(nn.Module):
''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes. ''' kernel sizes. '''
class ConvBnLelu(nn.Module): class ConvBnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True, weight_init_factor=1): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
super(ConvBnLelu, self).__init__() super(ConvBnLelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3} padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys() assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if bn: if norm:
self.bn = nn.BatchNorm2d(filters_out) self.bn = nn.BatchNorm2d(filters_out)
else: else:
self.bn = None self.bn = None
if lelu: if activation:
self.lelu = nn.LeakyReLU(negative_slope=.1) self.lelu = nn.LeakyReLU(negative_slope=.1)
else: else:
self.lelu = None self.lelu = None
@ -296,16 +296,16 @@ class ConvBnLelu(nn.Module):
''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard ''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes. ''' kernel sizes. '''
class ConvGnLelu(nn.Module): class ConvGnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, gn=True, bias=True, num_groups=8): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8):
super(ConvGnLelu, self).__init__() super(ConvGnLelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3} padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys() assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if gn: if norm:
self.gn = nn.GroupNorm(num_groups, filters_out) self.gn = nn.GroupNorm(num_groups, filters_out)
else: else:
self.gn = None self.gn = None
if lelu: if activation:
self.lelu = nn.LeakyReLU(negative_slope=.1) self.lelu = nn.LeakyReLU(negative_slope=.1)
else: else:
self.lelu = None self.lelu = None
@ -331,16 +331,16 @@ class ConvGnLelu(nn.Module):
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes. ''' kernel sizes. '''
class ConvGnSilu(nn.Module): class ConvGnSilu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, gn=True, bias=True, num_groups=8, weight_init_factor=1): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
super(ConvGnSilu, self).__init__() super(ConvGnSilu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3} padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys() assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
if gn: if norm:
self.gn = nn.GroupNorm(num_groups, filters_out) self.gn = nn.GroupNorm(num_groups, filters_out)
else: else:
self.gn = None self.gn = None
if silu: if activation:
self.silu = SiLU() self.silu = SiLU()
else: else:
self.silu = None self.silu = None
@ -363,4 +363,24 @@ class ConvGnSilu(nn.Module):
if self.silu: if self.silu:
return self.silu(x) return self.silu(x)
else: else:
return x return x
# Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed
# along with the feature representation.
class ExpansionBlock(nn.Module):
def __init__(self, filters, block=ConvGnSilu):
super(ExpansionBlock, self).__init__()
self.decimate = block(filters, filters // 2, kernel_size=1, bias=False, activation=False, norm=True)
self.process_passthrough = block(filters // 2, filters // 2, kernel_size=3, bias=True, activation=False, norm=True)
self.conjoin = block(filters, filters // 2, kernel_size=3, bias=False, activation=True, norm=False)
self.process = block(filters // 2, filters // 2, kernel_size=3, bias=False, activation=True, norm=True)
# input is the feature signal with shape (b, f, w, h)
# passthrough is the structure signal with shape (b, f/2, w*2, h*2)
# output is conjoined upsample with shape (b, f/2, w*2, h*2)
def forward(self, input, passthrough):
x = F.interpolate(input, scale_factor=2, mode="nearest")
x = self.decimate(x)
p = self.process_passthrough(passthrough)
x = self.conjoin(torch.cat([x, p], dim=1))
return self.process(x)

View File

@ -108,20 +108,20 @@ class Discriminator_VGG_PixLoss(nn.Module):
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True) self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False) self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False)
self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, gn=False, lelu=False) self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
# Pyramid network: upsample with residuals and produce losses at multiple resolutions. # Pyramid network: upsample with residuals and produce losses at multiple resolutions.
self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, lelu=False) self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, activation=False)
self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False) self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False)
self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False) self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False)
self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False) self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False)
self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, gn=False, lelu=False) self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, norm=False, activation=False)
self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, lelu=False) self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, activation=False)
self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False) self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False)
self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False) self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False)
self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False) self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False)
self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, gn=False, lelu=False) self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, norm=False, activation=False)
# activation function # activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)