From 6ac6c951777799106071e22c1b4032b2a5e47e5a Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Wed, 1 Jul 2020 16:42:27 -0600
Subject: [PATCH] Fix scaling bug in ConfigurableSwitchComputer

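The learned scale in ConfigurableSwitchComputer was applied to the whole
residual sum, identity path included. Split it into two per-branch
scalars: switch_scale on the switch output and psc_scale (initialized
near zero) on the re-enabled post_switch_conv, each applied before the
identity add. ConfigurableSwitchedResidualGenerator2 now relies on
init_scalar=.01 rather than the explicit initialize_weights calls on
the switches and transforms.

In effect (a rough sketch of the output computation; switch_out stands
for the raw switch output):

    # before
    out = (identity + switch_out) * scale + bias
    # after
    out = identity + switch_out * switch_scale
    out = identity + post_switch_conv(out) * psc_scale + bias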
---
 .../archs/SwitchedResidualGenerator_arch.py      | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index b2dc8faf..faae11c4 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -192,8 +192,10 @@ class ConfigurableSwitchComputer(nn.Module):
 
         # And the switch itself, including learned scalars
         self.switch = BareConvSwitch(initial_temperature=init_temp)
+        self.switch_scale = nn.Parameter(torch.full((1,), float(init_scalar)))
         self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=False)
-        self.scale = nn.Parameter(torch.full((1,), float(init_scalar)))
+        # The post_switch_conv output gets a near-zero scale; the network can learn to magnify it (or not) as needed.
+        self.psc_scale = nn.Parameter(torch.full((1,), float(1e-3)))
         self.bias = nn.Parameter(torch.zeros(1))
 
     def forward(self, x, output_attention_weights=False):
@@ -211,9 +213,9 @@ class ConfigurableSwitchComputer(nn.Module):
         m = F.interpolate(m, size=x.shape[2:], mode='nearest')
 
         outputs, attention = self.switch(xformed, m, True)
-        outputs = identity + outputs
-        #outputs = identity + self.post_switch_conv(outputs)
-        outputs = outputs * self.scale + self.bias
+        outputs = identity + outputs * self.switch_scale
+        outputs = identity + self.post_switch_conv(outputs) * self.psc_scale
+        outputs = outputs + self.bias
         if output_attention_weights:
             return outputs, attention
         else:
@@ -361,11 +363,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
                                                        trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms,
-                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1))
-
-        initialize_weights(switches, 1)
-        # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
-        initialize_weights([s.transforms for s in switches], .2 / len(switches))
+                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
 
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts