From 2f2f87bbea33c93a5d3ff912cabbd7336c384ad7 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 5 Jan 2021 20:14:39 -0700
Subject: [PATCH] Styled SR fixes

---
 codes/models/styled_sr/styled_sr.py      | 13 +++++++++----
 codes/models/styled_sr/stylegan2_base.py |  4 ++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/codes/models/styled_sr/styled_sr.py b/codes/models/styled_sr/styled_sr.py
index c40bb00b..07447dd7 100644
--- a/codes/models/styled_sr/styled_sr.py
+++ b/codes/models/styled_sr/styled_sr.py
@@ -130,9 +130,11 @@ class StyledSrGenerator(nn.Module):
         # Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later.
         self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
         self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
+        self.l2 = nn.MSELoss()
         self.mixed_prob = .9
         self._init_weights()
         self.transfer_mode = transfer_mode
+        self.initial_stride = initial_stride
         if transfer_mode:
             for p in self.parameters():
                 if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
@@ -174,11 +176,14 @@ class StyledSrGenerator(nn.Module):
 
         out = self.gen(x, w_styles)
 
-        # Compute the net, areal, pixel-wise additions made on top of the LR image.
-        out_down = F.interpolate(out, size=(x.shape[-2], x.shape[-1]), mode="area")
-        diff = torch.sum(torch.abs(out_down - x), dim=[1,2,3])
+        # Compute an L2 loss between the LR input and the generated image after both are area-downsampled to
+        # (LR size // initial_stride); used for regularization.
+        out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
+        if self.initial_stride > 1:
+            x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
+        l2_reg = self.l2(x, out_down)
 
-        return out, diff, w_styles
+        return out, l2_reg, w_styles
 
 
 if __name__ == '__main__':
diff --git a/codes/models/styled_sr/stylegan2_base.py b/codes/models/styled_sr/stylegan2_base.py
index 2738725a..dff9be14 100644
--- a/codes/models/styled_sr/stylegan2_base.py
+++ b/codes/models/styled_sr/stylegan2_base.py
@@ -396,8 +396,8 @@ class GeneratorBlock(nn.Module):
             x = self.upsample(x)
 
         inoise = inoise[:, :x.shape[2], :x.shape[3], :]
-        noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
-        noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
+        noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
+        noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
 
         style1 = self.to_style1(istyle)
         x = self.conv1(x, style1)