diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index f28b0d73..c435c34a 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -361,7 +361,12 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, - add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01)) + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1)) + + initialize_weights(switches, 1) + # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. + initialize_weights([s.transforms for s in switches], .2 / len(switches)) + self.switches = nn.ModuleList(switches) self.transformation_counts = trans_counts self.init_temperature = initial_temp