From e2a146abc7c85e40c9946d9886bc3dae5a8be4fa Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Sat, 19 Sep 2020 10:05:25 -0600
Subject: [PATCH] Add in experiments hook

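Experiments are objects that hook into ExtensibleTrainer at fixed points in
each training step: before_step() is called once the trainable networks for
the step have been selected, before_optimize() and after_optimize() bracket
the step's do_step() call, and get_log_data() is merged into the training
log. Experiments are instantiated from an 'experiments' list in the opt file
via get_experiment_for_name().

A minimal sketch of a custom experiment follows; the class and the assumed
network name 'generator' are illustrative only and not part of this patch:

    from models.experiments.experiments import Experiment

    class GeneratorGradNormTracker(Experiment):
        # Records the overall gradient norm of one network right after its
        # optimizer step and surfaces it through the trainer's log.
        def before_step(self, opt, step_name, env, nets_to_train, pre_state):
            self.tracked_net = env['networks'].get('generator', None)
            self.grad_norm = None

        def after_optimize(self, state):
            if self.tracked_net is None:
                return
            total = 0.0
            for p in self.tracked_net.parameters():
                if p.grad is not None:
                    total += float(p.grad.detach().norm()) ** 2
            self.grad_norm = total ** 0.5

        def get_log_data(self):
            if self.grad_norm is None:
                return {}
            return {'generator_grad_norm': self.grad_norm}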
---
 codes/data/util.py                      |  2 +-
 codes/models/ExtensibleTrainer.py       | 20 +++++-
 codes/models/experiments/experiments.py | 83 +++++++++++++++++++++++++
 3 files changed, 103 insertions(+), 2 deletions(-)
 create mode 100644 codes/models/experiments/experiments.py

diff --git a/codes/data/util.py b/codes/data/util.py
index d7f7d302..7d42c0d1 100644
--- a/codes/data/util.py
+++ b/codes/data/util.py
@@ -25,7 +25,7 @@ def _get_paths_from_images(path):
     images = []
     for dirpath, _, fnames in sorted(os.walk(path)):
         for fname in sorted(fnames):
-            if is_image_file(fname):
+            if is_image_file(fname) and 'ref.jpg' not in fname:
                 img_path = os.path.join(dirpath, fname)
                 images.append(img_path)
     assert images, '{:s} has no valid image file'.format(path)
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 2484c939..3a56cdaf 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -10,6 +10,7 @@ import models.lr_scheduler as lr_scheduler
 import models.networks as networks
 from models.base_model import BaseModel
 from models.steps.steps import ConfigurableStep
+from models.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
 
 logger = logging.getLogger('base')
@@ -37,6 +38,7 @@ class ExtensibleTrainer(BaseModel):
 
         self.netsG = {}
         self.netsD = {}
+        # Note that this is on the chopping block. It should be integrated into an injection point.
         self.netF = networks.define_F().to(self.device)  # Used to compute feature loss.
         for name, net in opt['networks'].items():
             # Trainable is a required parameter, but the default is simply true. Set it here.
@@ -56,9 +58,11 @@ class ExtensibleTrainer(BaseModel):
                 new_net.eval()
 
         # Initialize the train/eval steps
+        self.step_names = []
         self.steps = []
         for step_name, step in opt['steps'].items():
             step = ConfigurableStep(step, self.env)
+            self.step_names.append(step_name)  # This could be an OrderedDict, but it's a PITA to integrate with AMP below.
             self.steps.append(step)
 
         # step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
@@ -91,11 +95,12 @@ class ExtensibleTrainer(BaseModel):
             amp_opts = self.optimizers
             self.env['amp'] = False
 
-        # Unwrap steps & netF
+        # Unwrap steps & netF & optimizers
         self.netF = amp_nets[len(total_nets)]
         assert(len(self.steps) == len(amp_nets[len(total_nets)+1:]))
         self.steps = amp_nets[len(total_nets)+1:]
         amp_nets = amp_nets[:len(total_nets)]
+        self.optimizers = amp_opts
 
         # DataParallel
         dnets = []
@@ -133,6 +138,11 @@ class ExtensibleTrainer(BaseModel):
         self.print_network()  # print network
         self.load()  # load G and D if needed
 
+        # Load experiments
+        self.experiments = []
+        if 'experiments' in opt.keys():
+            self.experiments = [get_experiment_for_name(e) for e in opt['experiments']]
+
         # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
         self.updated = True
 
@@ -185,6 +195,9 @@ class ExtensibleTrainer(BaseModel):
                         p.requires_grad = False
             assert enabled == len(nets_to_train)
 
+            # Update experiments
+            [e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
+
             for o in s.get_optimizers():
                 o.zero_grad()
 
@@ -205,7 +218,9 @@ class ExtensibleTrainer(BaseModel):
                 state[k] = v
 
             # And finally perform optimization.
+            [e.before_optimize(state) for e in self.experiments]
             s.do_step()
+            [e.after_optimize(state) for e in self.experiments]
 
         # Record visual outputs for usage in debugging and testing.
         if 'visuals' in self.opt['logger'].keys():
@@ -252,6 +267,9 @@ class ExtensibleTrainer(BaseModel):
         for s in self.steps:
             log.update(s.get_metrics())
 
+        for e in self.experiments:
+            log.update(e.get_log_data())
+
         # Some generators can do their own metric logging.
         for net_name, net in self.networks.items():
             if hasattr(net.module, "get_debug_values"):
diff --git a/codes/models/experiments/experiments.py b/codes/models/experiments/experiments.py
new file mode 100644
index 00000000..fec2bf82
--- /dev/null
+++ b/codes/models/experiments/experiments.py
@@ -0,0 +1,83 @@
+import torch
+
+def get_experiment_for_name(name):
+    # Dispatcher from the experiment names listed under 'experiments' in the opt file to Experiment instances.
+    # Currently a stub that returns the no-op base class for every name.
+    return Experiment()
+
+
+# Experiments are hooks into the ExtensibleTrainer training process, intended for reporting on the inner workings
+# of training in custom ways that don't belong in the trainer itself.
+class Experiment:
+    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
+        pass
+
+    def before_optimize(self, state):
+        pass
+
+    def after_optimize(self, state):
+        pass
+
+    def get_log_data(self):
+        # Must return a dict; the trainer merges it into its log with dict.update().
+        return {}
+
+
+class ModelParameterDepthTrackerMetrics(Experiment):
+    # Subclasses should implement these two methods:
+    def get_network_and_step_names(self):
+        # Subclasses should return the name of the network being debugged and the name of the step that trains it: (net_name, step_name)
+        pass
+    def get_layers_to_debug(self, env, net, state):
+        # Subclasses should return a list of per-layer nn.Modules to track.
+        pass
+
+
+    def __init__(self):
+        self.step_num = 0
+        self.activate = False
+        self.layer_means = []
+
+    def before_step(self, opt, step_name, env, nets_to_train, pre_state):
+        self.net, step = self.get_network_and_step_names()
+        if self.net in nets_to_train and step == step_name:
+            self.step_num += 1
+        # Only gather parameters on iterations whose results will actually be logged.
+        self.activate = self.net in nets_to_train and step == step_name and self.step_num % opt['logger']['print_freq'] == 0
+        if self.activate:
+            layers = self.get_layers_to_debug(env, env['networks'][self.net], pre_state)
+            self.params = []
+            for l in layers:
+                # Track only the trainable parameters of each layer.
+                lparams = [v for v in l.parameters() if v.requires_grad]
+                self.params.append(lparams)
+
+    def before_optimize(self, state):
+        if not self.activate:
+            return
+        # Snapshot the tracked parameters so the post-step values can be compared against them.
+        self.cached_params = []
+        for l in self.params:
+            lparams = [p.detach().cpu().clone() for p in l]
+            self.cached_params.append(lparams)
+
+    def after_optimize(self, state):
+        if not self.activate:
+            return
+        # Compute the mean absolute parameter change per tracked layer across this optimizer step.
+        self.layer_means = []
+        for l, lc in zip(self.params, self.cached_params):
+            total = torch.tensor(0.0)
+            for p, pc in zip(l, lc):
+                total += torch.mean(torch.abs(pc - p.detach().cpu()))
+            self.layer_means.append(total / len(l))
+
+    def get_log_data(self):
+        if not self.layer_means:
+            return {}
+        return {'%s_layer_update_means_histogram' % (self.net,): self.layer_means}
+
+
+class DiscriminatorParameterTracker(ModelParameterDepthTrackerMetrics):
+    def get_network_and_step_names(self):
+        return "feature_discriminator", "feature_discriminator"
+
+    def get_layers_to_debug(self, env, net, state):
+        return [net.ref_head.conv0_0,
+                net.ref_head.conv0_1,
+                net.ref_head.conv1_0,
+                net.ref_head.conv1_1,
+                net.ref_head.conv2_0,
+                net.ref_head.conv2_1,
+                net.ref_head.conv3_0,
+                net.ref_head.conv3_1,
+                net.ref_head.conv4_0,
+                net.ref_head.conv4_1,
+                net.linear1,
+                net.output_linears]
\ No newline at end of file