Re-work SwitchedResgen2
Got rid of the converged multiplexer bases but kept the configurable architecture. The new multiplexers look a lot like the old ones. Took some cues from the transformer architecture: translate the image into a higher filter-space and stay there for the duration of the model's computation. Also perform convs after each switch to let the model anneal away any issues that arise.
This commit is contained in:
parent 42a10b34ce
commit 407224eba1
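For intuition, here is a minimal sketch of that pattern (a toy module, not the repo code; `FilterSpaceResidualNet` and all sizes are assumptions): lift the image into filter-space once, run every switch there as a residual step, and follow each one with a cleanup conv.

    import torch
    import torch.nn as nn

    # Toy sketch of the pattern from the commit message, NOT the actual repo code:
    # lift the 3-channel image into a wider filter-space once, keep all computation
    # there, and follow each residual "switch" with a conv that can smooth over
    # artifacts the switch introduces.
    class FilterSpaceResidualNet(nn.Module):
        def __init__(self, channels=64, num_switches=4):
            super().__init__()
            self.lift = nn.Conv2d(3, channels, 3, padding=1)      # image -> filter-space
            self.switches = nn.ModuleList(
                nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_switches))
            self.cleanup = nn.ModuleList(                         # post-switch convs
                nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_switches))
            self.project = nn.Conv2d(channels, 3, 3, padding=1)   # filter-space -> image

        def forward(self, x):
            x = self.lift(x)
            for sw, conv in zip(self.switches, self.cleanup):
                x = x + sw(x)    # residual switch step
                x = x + conv(x)  # residual cleanup step
            return self.project(x)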
@@ -175,9 +175,32 @@ class ConfigurableSwitchComputer(nn.Module):
         self.switch.set_attention_temperature(temp)
 
 
-class ResidualBasisMultiplexerBase(nn.Module):
+class ConvBasisMultiplexer(nn.Module):
+    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
+        super(ConvBasisMultiplexer, self).__init__()
+        self.filter_conv = ConvBnLelu(input_channels, base_filters)
+        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
+        reduction_filters = base_filters * 2 ** reductions
+        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
+
+        gap = self.output_filter_count - multiplexer_channels
+        self.cbl1 = ConvBnLelu(self.output_filter_count, self.output_filter_count - (gap // 4), bn=use_bn)
+        self.cbl2 = ConvBnLelu(self.output_filter_count - (gap // 4), self.output_filter_count - (gap // 2), bn=use_bn)
+        self.cbl3 = ConvBnLelu(self.output_filter_count - (gap // 2), multiplexer_channels)
+
+    def forward(self, x):
+        x = self.filter_conv(x)
+        x = self.reduction_blocks(x)
+        x = self.processing_blocks(x)
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+
+
+class ConvBasisMultiplexerBase(nn.Module):
     def __init__(self, input_channels, base_filters, growth, reductions, processing_depth):
-        super(ResidualBasisMultiplexerBase, self).__init__()
+        super(ConvBasisMultiplexerBase, self).__init__()
         self.filter_conv = ConvBnLelu(input_channels, base_filters)
         self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
         reduction_filters = base_filters * 2 ** reductions
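The three trailing ConvBnLelu layers step the channel count down toward multiplexer_channels: the first two each shave off a quarter of the gap, and the last conv drops the remaining half. A quick check with assumed (hypothetical) sizes:

    # Channel schedule of cbl1/cbl2/cbl3 with assumed values.
    output_filter_count, multiplexer_channels = 128, 8
    gap = output_filter_count - multiplexer_channels   # 120
    cbl1_out = output_filter_count - gap // 4          # 128 -> 98
    cbl2_out = output_filter_count - gap // 2          # 98 -> 68
    cbl3_out = multiplexer_channels                    # 68 -> 8
    print(cbl1_out, cbl2_out, cbl3_out)                # 98 68 8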
@@ -190,9 +213,9 @@ class ResidualBasisMultiplexerBase(nn.Module):
         return x
 
 
-class ResidualBasisMultiplexerLeaf(nn.Module):
+class ConvBasisMultiplexerLeaf(nn.Module):
     def __init__(self, base, filters, multiplexer_channels, use_bn=False):
-        super(ResidualBasisMultiplexerLeaf, self).__init__()
+        super(ConvBasisMultiplexerLeaf, self).__init__()
         assert(filters > multiplexer_channels)
         gap = filters - multiplexer_channels
         assert(gap % 4 == 0)
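This hunk is a pure rename; the leaf keeps its old constraints, which require the filter gap to split into even quarters. With assumed numbers:

    # The two asserts above, with hypothetical values: the gap must divide by 4.
    filters, multiplexer_channels = 72, 8
    gap = filters - multiplexer_channels   # 64
    assert filters > multiplexer_channels
    assert gap % 4 == 0                    # passes; filters=70 would give gap=62 and fail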
@@ -277,19 +300,24 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
 
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
     def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                  heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
         switches = []
-        multiplexer_base = ResidualBasisMultiplexerBase(3, switch_filters[0], switch_growths[0], switch_reductions[0], switch_processing_layers[0])
-        for trans_count, kernel, layers, mid_filters in zip(trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
-            leaf_fn = functools.partial(ResidualBasisMultiplexerLeaf, multiplexer_base, multiplexer_base.output_filter_count)
-            switches.append(ConfigurableSwitchComputer(leaf_fn, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+        post_switch_proc = []
+        self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False)
+        for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
+            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
+            switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(ResidualBranch, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+            post_switch_proc.append(ConvBnLelu(transformation_filters, transformation_filters, bn=False))
         initialize_weights(switches, 1)
         # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
         initialize_weights([s.transforms for s in switches], .2 / len(switches))
         self.switches = nn.ModuleList(switches)
+        initialize_weights([p for p in post_switch_proc], .01)
+        self.post_switch_convs = nn.ModuleList(post_switch_proc)
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
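Note how the constructor binds each multiplexer's arguments with functools.partial so ConfigurableSwitchComputer can instantiate the module itself. The pattern in isolation (toy class; all names and sizes here are assumptions):

    import functools

    # Toy stand-in showing the deferred-construction pattern used above:
    # bind everything but the last argument now, let the caller finish later.
    class ToyMultiplexer:
        def __init__(self, in_filters, base_filters, channels):
            self.shape = (in_filters, base_filters, channels)

    multiplx_fn = functools.partial(ToyMultiplexer, 64, 32)  # bind all but `channels`
    mux = multiplx_fn(8)   # the switch supplies the final argument when it builds the module
    print(mux.shape)       # (64, 32, 8)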
@@ -304,11 +332,16 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         if self.upsample_factor > 1:
             x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
+
+        x = self.initial_conv(x)
+
         self.attentions = []
-        for i, sw in enumerate(self.switches):
+        for i, (sw, conv) in enumerate(zip(self.switches, self.post_switch_convs)):
             sw_out, att = sw.forward(x, True)
-            x = x + sw_out
             self.attentions.append(att)
+            x = x + sw_out
+            x = x + conv(x)
 
+        x = self.final_conv(x)
         return x,
 
     def set_temperature(self, temp):
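Upsampling still happens once, before the image enters filter-space; nearest-neighbor interpolation just scales the spatial dims. A standalone check (assumed factor of 2):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 16, 16)
    y = F.interpolate(x, scale_factor=2, mode="nearest")
    print(y.shape)  # torch.Size([1, 3, 32, 32])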