From 17191de836dab9eeea1ed3ba854fab2ddde1fb3d Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Wed, 1 Jul 2020 15:57:55 -0600
Subject: [PATCH] Experiment: bring initialize_weights back again

Something really strange going on here..
---
 codes/models/archs/SwitchedResidualGenerator_arch.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index f28b0d73..c435c34a 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -361,7 +361,12 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
                                                        trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms,
-                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
+                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1))
+
+        initialize_weights(switches, 1)
+        # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
+        initialize_weights([s.transforms for s in switches], .2 / len(switches))
+
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp