diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index b2dc8faf..faae11c4 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -192,8 +192,10 @@ class ConfigurableSwitchComputer(nn.Module): # And the switch itself, including learned scalars self.switch = BareConvSwitch(initial_temperature=init_temp) + self.switch_scale = nn.Parameter(torch.full((1,), float(init_scalar))) self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=False) - self.scale = nn.Parameter(torch.full((1,), float(init_scalar))) + # The post_switch_conv gets a near-zero scale. The network can decide to magnify it (or not) depending on its needs. + self.psc_scale = nn.Parameter(torch.full((1,), float(1e-3))) self.bias = nn.Parameter(torch.zeros(1)) def forward(self, x, output_attention_weights=False): @@ -211,9 +213,9 @@ class ConfigurableSwitchComputer(nn.Module): m = F.interpolate(m, size=x.shape[2:], mode='nearest') outputs, attention = self.switch(xformed, m, True) - outputs = identity + outputs - #outputs = identity + self.post_switch_conv(outputs) - outputs = outputs * self.scale + self.bias + outputs = identity + outputs * self.switch_scale + outputs = identity + self.post_switch_conv(outputs) * self.psc_scale + outputs = outputs + self.bias if output_attention_weights: return outputs, attention else: @@ -361,11 +363,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, - add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1)) - - initialize_weights(switches, 1) - # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. - initialize_weights([s.transforms for s in switches], .2 / len(switches)) + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01)) self.switches = nn.ModuleList(switches) self.transformation_counts = trans_counts