Experiment: revert initialization changes

This commit is contained in:
James Betker 2020-07-01 12:08:09 -06:00
parent 78276afcaa
commit e2398ac83c

View File

@ -116,7 +116,7 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
convs = [] convs = []
current_filters = filters_init current_filters = filters_init
for i in range(num_convs): for i in range(num_convs):
convs.append(ConvBnLelu(current_filters, current_filters + filter_growth, bn=True, bias=False)) convs.append(ConvBnRelu(current_filters, current_filters + filter_growth, bn=True, bias=False))
current_filters += filter_growth current_filters += filter_growth
return nn.Sequential(*convs), current_filters return nn.Sequential(*convs), current_filters
@ -222,15 +222,15 @@ class ConfigurableSwitchComputer(nn.Module):
class ConvBasisMultiplexer(nn.Module): class ConvBasisMultiplexer(nn.Module):
def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True): def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
super(ConvBasisMultiplexer, self).__init__() super(ConvBasisMultiplexer, self).__init__()
self.filter_conv = ConvBnLelu(input_channels, base_filters, bias=True) self.filter_conv = ConvBnRelu(input_channels, base_filters, bias=True)
self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)])) self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
reduction_filters = base_filters * 2 ** reductions reduction_filters = base_filters * 2 ** reductions
self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth) self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
gap = self.output_filter_count - multiplexer_channels gap = self.output_filter_count - multiplexer_channels
self.cbl1 = ConvBnLelu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False) self.cbl1 = ConvBnRelu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
self.cbl2 = ConvBnLelu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False) self.cbl2 = ConvBnRelu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
self.cbl3 = ConvBnLelu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True) self.cbl3 = ConvBnRelu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
def forward(self, x): def forward(self, x):
x = self.filter_conv(x) x = self.filter_conv(x)
@ -351,14 +351,17 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
switches = [] switches = []
post_switch_proc = [] post_switch_proc = []
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False) self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=False) self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False) self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False)
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers): for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count) multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)) switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
post_switch_proc.append(ConvBnLelu(transformation_filters, transformation_filters, bn=False)) post_switch_proc.append(ConvBnLelu(transformation_filters, transformation_filters, bn=False))
initialize_weights(switches, 1)
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
initialize_weights([s.transforms for s in switches], .2 / len(switches))
self.switches = nn.ModuleList(switches) self.switches = nn.ModuleList(switches)
initialize_weights([p for p in post_switch_proc], .01)
self.post_switch_convs = nn.ModuleList(post_switch_proc) self.post_switch_convs = nn.ModuleList(post_switch_proc)
self.transformation_counts = trans_counts self.transformation_counts = trans_counts
self.init_temperature = initial_temp self.init_temperature = initial_temp
@ -384,7 +387,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
x = self.proc_conv(x) x = self.proc_conv(x)
x = self.final_conv(x) x = self.final_conv(x)
return x / 13, return x,
def set_temperature(self, temp): def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches] [sw.set_temperature(temp) for sw in self.switches]