From 70c764b9d4508b19f062ac3859dce7d50c059a98 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 16 Jun 2020 13:24:07 -0600
Subject: [PATCH] Create a configurable SwitchedResidualGenerator

The new ConfigurableSwitchedResidualGenerator accepts per-switch lists
(switch_filters, switch_depths, trans_counts, trans_kernel_sizes,
trans_layers), so the whole switch stack can be described in the
network options instead of being hard-coded.

Also move the attention map image generator out of this repo; it is now
imported from switched_conv_util as save_attention_to_image.
---
 .../archs/SwitchedResidualGenerator_arch.py   | 81 +++++++++++--------
 codes/models/networks.py                      |  4 +
 2 files changed, 53 insertions(+), 32 deletions(-)

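Note: both generators anneal the switch attention temperature from
inside get_debug_values(). The helper below restates the schedule used
in this patch as a standalone sketch (the helper name is made up for
illustration; the defaults shown match the constructors'
initial_temp=20 and final_temperature_step=50000):

    def anneal_temperature(step, initial_temp=20, final_temperature_step=50000):
        # Linear decay from initial_temp down to 1, clamped at 1 thereafter.
        return max(1, int(initial_temp * (final_temperature_step - step)
                          / final_temperature_step))

    anneal_temperature(0)       # -> 20
    anneal_temperature(25000)   # -> 10
    anneal_temperature(47500)   # -> 1 (stays clamped at 1 from here on)
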
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index bad60da7..9a472cae 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -4,8 +4,7 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
 from models.archs.arch_util import initialize_weights
-import torchvision
-from torchvision import transforms
+from switched_conv_util import save_attention_to_image
 
 
 class ConvBnLelu(nn.Module):
@@ -90,7 +89,6 @@ class SwitchComputer(nn.Module):
     def set_temperature(self, temp):
         self.switch.set_attention_temperature(temp)
 
-
 class SwitchedResidualGenerator(nn.Module):
     def __init__(self, switch_filters, initial_temp=20, final_temperature_step=50000):
         super(SwitchedResidualGenerator, self).__init__()
@@ -137,33 +135,55 @@ class SwitchedResidualGenerator(nn.Module):
         self.switch3.set_temperature(temp)
         self.switch4.set_temperature(temp)
 
-    # Copied from torchvision.utils.save_image. Allows specifying pixel format.
-    def save_image(self, tensor, fp, nrow=8, padding=2,
-                   normalize=False, range=None, scale_each=False, pad_value=0, format=None, pix_format=None):
-        from PIL import Image
-        grid = torchvision.utils.make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
-                         normalize=normalize, range=range, scale_each=scale_each)
-        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
-        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
-        im = Image.fromarray(ndarr, mode=pix_format).convert('RGB')
-        im.save(fp, format=format)
+    def get_debug_values(self, step):
+        # Take the opportunity to update the temperature here.
+        temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
+        self.set_temperature(temp)
 
-    def convert_attention_indices_to_image(self, attention_out, attention_size, step, fname_part="map", l_mult=1.0):
-        magnitude, indices = torch.topk(attention_out, 1, dim=-1)
-        magnitude = magnitude.squeeze(3)
-        indices = indices.squeeze(3)
-        # indices is an integer tensor (b,w,h) where values are on the range [0,attention_size]
-        # magnitude is a float tensor (b,w,h) [0,1] representing the magnitude of that attention.
-        # Use HSV colorspace to show this. Hue is mapped to the indices, Lightness is mapped to intensity,
-        # Saturation is left fixed.
-        hue = indices.float() / attention_size
-        saturation = torch.full_like(hue, .8)
-        value = magnitude * l_mult
-        hsv_img = torch.stack([hue, saturation, value], dim=1)
+        if step % 250 == 0:
+            save_attention_to_image(self.a1, 4, step, "a1")
+            save_attention_to_image(self.a2, 8, step, "a2")
+            save_attention_to_image(self.a3, 16, step, "a3", 2)
+            save_attention_to_image(self.a4, 32, step, "a4", 4)
 
-        import os
-        os.makedirs("attention_maps/%s" % (fname_part,), exist_ok=True)
-        self.save_image(hsv_img, "attention_maps/%s/attention_map_%i.png" % (fname_part, step,), pix_format="HSV")
+        val = {"switch_temperature": temp}
+        for i in range(len(self.running_sum)):
+            val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
+            self.running_sum[i] = 0
+        self.running_count = 0
+        return val
+
+
+class ConfigurableSwitchedResidualGenerator(nn.Module):
+    def __init__(self, switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers, initial_temp=20, final_temperature_step=50000):
+        super(ConfigurableSwitchedResidualGenerator, self).__init__()
+        switches = []
+        for filters, depth, trans_count, kernel, layers in zip(switch_filters, switch_depths, trans_counts, trans_kernel_sizes, trans_layers):
+            switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, depth, initial_temp))
+        initialize_weights(switches, 1)
+        # Initialize the transforms with a smaller weight, since their outputs are repeatedly added onto the resultant image.
+        initialize_weights([s.transforms for s in switches], .05)
+        self.switches = nn.ModuleList(switches)
+        self.transformation_counts = trans_counts
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.running_sum = [0 for i in range(len(switches))]
+        self.running_count = 0
+
+    def forward(self, x):
+        self.attentions = []
+        for i, sw in enumerate(self.switches):
+            x, att = sw.forward(x, True)  # 'True' requests the attention map along with the output
+            self.attentions.append(att)
+            spec, _ = compute_attention_specificity(att, 2)
+            self.running_sum[i] += spec  # accumulated until the next get_debug_values() call
+
+        self.running_count += 1
+
+        return (x,)
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
 
     def get_debug_values(self, step):
         # Take the chance to update the temperature here.
@@ -171,10 +191,7 @@ class SwitchedResidualGenerator(nn.Module):
         self.set_temperature(temp)
 
         if step % 250 == 0:
-            self.convert_attention_indices_to_image(self.a1, 4, step, "a1")
-            self.convert_attention_indices_to_image(self.a2, 8, step, "a2")
-            self.convert_attention_indices_to_image(self.a3, 16, step, "a3", 2)
-            self.convert_attention_indices_to_image(self.a4, 32, step, "a4", 4)
+            [save_attention_to_image(self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,), l_mult=float(self.transformation_counts[i]/4)) for i in range(len(self.switches))]
 
         val = {"switch_temperature": temp}
         for i in range(len(self.running_sum)):
diff --git a/codes/models/networks.py b/codes/models/networks.py
index f1ee145c..6c70fcef 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -72,6 +72,10 @@ def define_G(opt, net_key='network_G'):
     elif which_model == "SwitchedResidualGenerator":
         netG = SwitchedGen_arch.SwitchedResidualGenerator(switch_filters=opt_net['nf'], initial_temp=opt_net['temperature'],
                                                           final_temperature_step=opt_net['temperature_final_step'])
+    elif which_model == "ConfigurableSwitchedResidualGenerator":
+        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(switch_filters=opt_net['switch_filters'], switch_depths=opt_net['switch_depths'], trans_counts=opt_net['trans_counts'],
+                                                                      trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
+                                                                      initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])
 
     # image corruption
     elif which_model == 'HighToLowResNet':
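
Usage sketch (illustrative values, not a recommended configuration):
the five list arguments must all be the same length, since the
constructor zips over them. Import paths assume running from the codes/
directory, as the repo's other imports do.

    import torch
    from models.archs.SwitchedResidualGenerator_arch import (
        ConfigurableSwitchedResidualGenerator)

    gen = ConfigurableSwitchedResidualGenerator(
        switch_filters=[32, 32, 64],   # filters inside each SwitchComputer
        switch_depths=[4, 4, 4],       # processing depth of each switch
        trans_counts=[4, 8, 16],       # selectable transforms per switch
        trans_kernel_sizes=[3, 3, 3],  # ResidualBranch kernel sizes
        trans_layers=[2, 2, 2],        # conv layers per ResidualBranch
        initial_temp=20,
        final_temperature_step=50000)

    out, = gen(torch.randn(1, 3, 64, 64))  # forward returns a 1-tuple
    stats = gen.get_debug_values(step=1)    # {"switch_temperature": ...,
                                            #  "switch_0_specificity": ..., ...}

To reach this path through define_G, the network_G options need
which_model set to "ConfigurableSwitchedResidualGenerator" plus the
keys read above: switch_filters, switch_depths, trans_counts,
trans_kernel_sizes, trans_layers, temperature and
temperature_final_step.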