From d0957bd7d425492d46c36c75dace00eaa4f7f0b5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 5 Jul 2020 17:32:46 -0600 Subject: [PATCH] Alter weight initialization for transformation blocks --- codes/models/archs/SwitchedResidualGenerator_arch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index cf102f40..a147510d 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -161,7 +161,9 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False), transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms, - add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01)) + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1)) + # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. + initialize_weights([s.transforms for s in switches], .2 / len(switches)) self.switches = nn.ModuleList(switches) self.transformation_counts = trans_counts