Alter weight initialization for transformation blocks

James Betker 2020-07-05 17:32:46 -06:00
parent 16d1bf6dd7
commit d0957bd7d4


@@ -161,7 +161,9 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
                                               pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
                                               transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
                                               transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
-                                              add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
+                                              add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1))
+        # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
+        initialize_weights([s.transforms for s in switches], .2 / len(switches))
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
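The effect of the change: instead of shrinking the transforms implicitly through init_scalar=.01, each switch's transforms are now explicitly re-initialized with a small weight scale of .2 / len(switches), so that the residual contributions accumulated across all switches start small. For context, a minimal sketch of what an MMSR/BasicSR-style initialize_weights helper conventionally does (the actual helper lives elsewhere in this repository; this is an assumption based on that codebase lineage, not the committed implementation):

import torch.nn as nn
import torch.nn.init as init

def initialize_weights(net_l, scale=1):
    # Hypothetical sketch: Kaiming-initialize conv/linear weights, then
    # multiply them by `scale` to damp residual-branch contributions.
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # e.g. .2 / len(switches) here
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

Dividing the scale by len(switches) keeps the total magnitude of the summed transform outputs roughly constant as the number of switches grows, which is why the per-switch weight gets smaller with more switches.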