From 75f148022deff3969988a9eec43602fff412290a Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 30 Jun 2020 13:52:47 -0600
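Subject: [PATCH] Even more NSG improvements (r4)

Switch (when not pass_chain_forward): drop c_conjoin and add a
"parameterize" conv (16 -> 16, no BN, no LeLU) applied to x. The context
is now pixel-shuffled, interpolated to x's spatial size, concatenated
with parameterize(x), and only then passed through c_constric (now
MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False)); the result
is divided by 1.6.

NestedSwitchComputer.forward: keep the input as feed_forward, add it to
the switch output, and return feed_forward + anneal(x) / .86.

NestedSwitchedGenerator: add proc_conv ahead of final_conv (its output
is divided by .85), assign each switch's output directly to x instead of
adding it in the generator loop, and divide the final output by 4.26.

MultiConvBlock: add a bn keyword (default False) that is forwarded to
every ConvBnLelu except the last one.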

---
 codes/models/archs/NestedSwitchGenerator.py   | 21 ++++++++++---------
 .../archs/SwitchedResidualGenerator_arch.py   |  6 +++---
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/codes/models/archs/NestedSwitchGenerator.py b/codes/models/archs/NestedSwitchGenerator.py
index 4724e421..056ae513 100644
--- a/codes/models/archs/NestedSwitchGenerator.py
+++ b/codes/models/archs/NestedSwitchGenerator.py
@@ -75,8 +75,8 @@ class Switch(nn.Module):
         self.bias = nn.Parameter(torch.zeros(1))
 
         if not self.pass_chain_forward:
-            self.c_constric = MultiConvBlock(32, 32, 16, 3, 3)
-            self.c_conjoin = ConvBnLelu(32, 16, kernel_size=1, bn=False)
+            self.parameterize = ConvBnLelu(16, 16, bn=False, lelu=False)
+            self.c_constric = MultiConvBlock(48, 32, 16, kernel_size=5, depth=3, bn=False)
 
     # x is the input fed to the transform blocks.
     # m is the output of the multiplexer which will be used to select from those transform blocks.
@@ -91,11 +91,9 @@ class Switch(nn.Module):
             # out in a normal distribution.
             context = (chain[-1] - 6) / 9.4
             context = F.pixel_shuffle(context, 4)
-            context = self.c_constric(context)
-
             context = F.interpolate(context, size=x.shape[2:], mode='nearest')
-            context = torch.cat([x, context], dim=1)
-            context = self.c_conjoin(context)
+            context = torch.cat([self.parameterize(x), context], dim=1)
+            context = self.c_constric(context) / 1.6
 
             if self.add_noise:
                 rand_feature = torch.randn_like(x)
@@ -224,6 +222,7 @@ class NestedSwitchComputer(nn.Module):
         nn.init.kaiming_normal_(self.multiplexer_init_conv.weight, nonlinearity="relu")
 
     def forward(self, x):
+        feed_forward = x
         trunk = []
         trunk_input = self.multiplexer_init_conv(x)
         for m in self.processing_trunk:
@@ -232,7 +231,8 @@ class NestedSwitchComputer(nn.Module):
 
         self.trunk = (trunk[-1] - 6) / 9.4
         x, att = self.switch.forward(x, trunk)
-        return self.anneal(x), att
+        x = x + feed_forward
+        return feed_forward + self.anneal(x) / .86, att
 
     def set_temperature(self, temp):
         self.switch.set_temperature(temp)
@@ -244,6 +244,7 @@ class NestedSwitchedGenerator(nn.Module):
                  heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
         super(NestedSwitchedGenerator, self).__init__()
         self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, lelu=False, bn=False)
+        self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
         self.final_conv = ConvBnLelu(transformation_filters, 3, kernel_size=1, lelu=False, bn=False)
 
         switches = []
@@ -271,12 +272,12 @@ class NestedSwitchedGenerator(nn.Module):
 
         self.attentions = []
         for i, sw in enumerate(self.switches):
-            sw_out, att = sw.forward(x)
+            x, att = sw.forward(x)
             self.attentions.append(att)
-            x = x + sw_out
 
+        x = self.proc_conv(x) / .85
         x = self.final_conv(x)
-        return x,
+        return x / 4.26,
 
     def set_temperature(self, temp):
         [sw.set_temperature(temp) for sw in self.switches]
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index fbf52dad..63f36aac 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -43,12 +43,12 @@ class ConvBnLelu(nn.Module):
 
 
 class MultiConvBlock(nn.Module):
-    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1):
+    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=False)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=False) for i in range(depth-2)] +
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn) for i in range(depth-2)] +
                                      [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))