Re-work SwitchedResgen2

Got rid of the converged multiplexer bases (each switch now builds its own full
multiplexer instead of sharing a base) but kept the configurable architecture. The
new multiplexers look a lot like the old ones.

Took some cues from the transformer architecture: translate the image into a higher
filter-space and stay there for the duration of the model's computation. Also perform
convs after each switch to allow the model to anneal issues that arise.
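
Roughly, the new forward pass looks like this (a sketch for orientation only; the
module names follow the diff below, and the optional nearest-neighbor upsample is
omitted):

    # Sketch of the reworked SwitchedResgen2 forward pass.
    x = initial_conv(lr_image)                 # 3 channels -> transformation_filters
    for switch, post_conv in zip(switches, post_switch_convs):
        switch_out, attention = switch(x, True)
        x = x + switch_out                     # residual add, staying in filter-space
        x = x + post_conv(x)                   # post-switch conv to anneal artifacts
    hr_image = final_conv(x)                   # transformation_filters -> 3 channels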
James Betker 2020-06-25 18:17:05 -06:00
parent 42a10b34ce
commit 407224eba1


@@ -175,9 +175,32 @@ class ConfigurableSwitchComputer(nn.Module):
         self.switch.set_attention_temperature(temp)
 
 
-class ResidualBasisMultiplexerBase(nn.Module):
+class ConvBasisMultiplexer(nn.Module):
+    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
+        super(ConvBasisMultiplexer, self).__init__()
+        self.filter_conv = ConvBnLelu(input_channels, base_filters)
+        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
+        reduction_filters = base_filters * 2 ** reductions
+        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
+        gap = self.output_filter_count - multiplexer_channels
+        self.cbl1 = ConvBnLelu(self.output_filter_count, self.output_filter_count - (gap // 4), bn=use_bn)
+        self.cbl2 = ConvBnLelu(self.output_filter_count - (gap // 4), self.output_filter_count - (gap // 2), bn=use_bn)
+        self.cbl3 = ConvBnLelu(self.output_filter_count - (gap // 2), multiplexer_channels)
+
+    def forward(self, x):
+        x = self.filter_conv(x)
+        x = self.reduction_blocks(x)
+        x = self.processing_blocks(x)
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+
+
+class ConvBasisMultiplexerBase(nn.Module):
     def __init__(self, input_channels, base_filters, growth, reductions, processing_depth):
-        super(ResidualBasisMultiplexerBase, self).__init__()
+        super(ConvBasisMultiplexerBase, self).__init__()
         self.filter_conv = ConvBnLelu(input_channels, base_filters)
         self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
         reduction_filters = base_filters * 2 ** reductions
@@ -190,9 +213,9 @@ class ResidualBasisMultiplexerBase(nn.Module):
         return x
 
 
-class ResidualBasisMultiplexerLeaf(nn.Module):
+class ConvBasisMultiplexerLeaf(nn.Module):
     def __init__(self, base, filters, multiplexer_channels, use_bn=False):
-        super(ResidualBasisMultiplexerLeaf, self).__init__()
+        super(ConvBasisMultiplexerLeaf, self).__init__()
         assert(filters > multiplexer_channels)
         gap = filters - multiplexer_channels
         assert(gap % 4 == 0)
@@ -277,19 +300,24 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
 
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
         switches = []
-        multiplexer_base = ResidualBasisMultiplexerBase(3, switch_filters[0], switch_growths[0], switch_reductions[0], switch_processing_layers[0])
-        for trans_count, kernel, layers, mid_filters in zip(trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
-            leaf_fn = functools.partial(ResidualBasisMultiplexerLeaf, multiplexer_base, multiplexer_base.output_filter_count)
-            switches.append(ConfigurableSwitchComputer(leaf_fn, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+        post_switch_proc = []
+        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False)
+        for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
+            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
+            switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(ResidualBranch, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+            post_switch_proc.append(ConvBnLelu(transformation_filters, transformation_filters, bn=False))
 
         initialize_weights(switches, 1)
         # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
         initialize_weights([s.transforms for s in switches], .2 / len(switches))
         self.switches = nn.ModuleList(switches)
+        initialize_weights([p for p in post_switch_proc], .01)
+        self.post_switch_convs = nn.ModuleList(post_switch_proc)
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
@@ -304,11 +332,16 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         if self.upsample_factor > 1:
             x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
+        x = self.initial_conv(x)
 
         self.attentions = []
-        for i, sw in enumerate(self.switches):
+        for i, (sw, conv) in enumerate(zip(self.switches, self.post_switch_convs)):
             sw_out, att = sw.forward(x, True)
-            x = x + sw_out
             self.attentions.append(att)
+            x = x + sw_out
+            x = x + conv(x)
+        x = self.final_conv(x)
         return x,
 
     def set_temperature(self, temp):
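
For reference, a hypothetical instantiation of the reworked generator. The argument
values here are illustrative only; the per-switch list arguments are zipped together
in the constructor, so they must all share a length (two switches in this example):

    gen = ConfigurableSwitchedResidualGenerator2(
        switch_filters=[32, 32], switch_growths=[16, 16],
        switch_reductions=[3, 3], switch_processing_layers=[2, 2],
        trans_counts=[8, 8], trans_kernel_sizes=[3, 3], trans_layers=[3, 3],
        transformation_filters=64, upsample_factor=4)
    out, = gen(lr_batch)  # forward returns a 1-tuple; lr_batch: (N, 3, H, W)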