From e3adafbeaceab43858e93f7b21e22ae48dae8549 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Wed, 22 Jul 2020 11:39:45 -0600
Subject: [PATCH] Add convert_model.py and a hacky way to add extra layers to a
 model

---
 codes/models/SRGAN_model.py               |  7 +++
 codes/models/archs/ProgressiveSrg_arch.py | 26 ++++++++
 codes/utils/convert_model.py              | 75 +++++++++++++++++++++++
 3 files changed, 108 insertions(+)
 create mode 100644 codes/utils/convert_model.py

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 697c5575..60baf25b 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -182,6 +182,9 @@ class SRGANModel(BaseModel):
         self.load()  # load G and D if needed
         self.load_random_corruptor()
 
+        # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
+        self.updated = True
+
     def feed_data(self, data, need_GT=True):
         _profile = True
         if _profile:
@@ -203,6 +206,10 @@ class SRGANModel(BaseModel):
             self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
             self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
 
+        if not self.updated:
+            self.netG.module.update_model(self.optimizer_G, self.schedulers[0])
+            self.updated = True
+
     def optimize_parameters(self, step):
         _profile = False
         if _profile:
diff --git a/codes/models/archs/ProgressiveSrg_arch.py b/codes/models/archs/ProgressiveSrg_arch.py
index df09cdce..ee0043f8 100644
--- a/codes/models/archs/ProgressiveSrg_arch.py
+++ b/codes/models/archs/ProgressiveSrg_arch.py
@@ -75,6 +75,32 @@ class GrowingSRGBase(nn.Module):
             param_groups.append({'params': sw_param_group})
         return param_groups
 
+    # This is a hacky way of modifying the underlying model while training. Since changing the model means changing
+    # the optimizer and the scheduler, these things are fed in. For ProgressiveSrg, this function adds an additional
+# switch to the end of the chain with depth=3 and an online time set at the end fo the function.
+    def update_model(self, opt, sched):
+        multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
+                                    3, self.switch_processing_layers, self.transformation_counts)
+        pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
+                                            bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
+                                         self.transformation_filters, kernel_size=3, depth=self.trans_layers,
+                                         weight_init_factor=.1)
+        new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
+                                                           pre_transform_block=pretransform_fn,
+                                                           transform_block=transform_fn,
+                                                           transform_count=self.transformation_counts, init_temp=self.init_temperature,
+                                                           add_scalable_noise_to_transforms=self.add_noise_to_transform,
+                                                           attention_norm=False).to('cuda')
+        self.progressive_switches.append(new_sw)
+        new_sw_param_group = []
+        for k, v in new_sw.named_parameters():
+            if v.requires_grad:
+                new_sw_param_group.append(v)
+        opt.add_param_group({'params': new_sw_param_group})
+        self.progressive_schedule.append(150000)
+        sched.group_starts.append(150000)
+
     def get_progressive_starts(self):
         # The base param group starts at step 0, the rest are defined via progressive_switches.
         return [0] + self.progressive_schedule
diff --git a/codes/utils/convert_model.py b/codes/utils/convert_model.py
new file mode 100644
index 00000000..6d31aac6
--- /dev/null
+++ b/codes/utils/convert_model.py
@@ -0,0 +1,75 @@
+# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
+# models which can be trained piecemeal.
+
+import options.options as option
+from models import create_model
+import torch
+import os
+
+def get_model_for_opt_file(filename):
+    opt = option.parse(filename, is_train=True)
+    opt = option.dict_to_nonedict(opt)
+    model = create_model(opt)
+    return model, opt
+
+def copy_state_dict_list(l_from, l_to):
+    for i, v in enumerate(l_from):
+        if isinstance(v, list):
+            copy_state_dict_list(v, l_to[i])
+        elif isinstance(v, dict):
+            copy_state_dict(v, l_to[i])
+        else:
+            l_to[i] = v
+
+def copy_state_dict(dict_from, dict_to):
+    for k in dict_from.keys():
+        if k == 'optimizers':
+            for j in range(len(dict_from[k][0]['param_groups'])):
+                for p in dict_to[k][0]['param_groups'][j]['params']:
+                    del dict_to[k][0]['state']
+                dict_to[k][0]['param_groups'][j] = dict_from[k][0]['param_groups'][j]
+            dict_to[k][0]['state'].update(dict_from[k][0]['state'])
+            print(len(dict_from[k][0].keys()), dict_from[k][0].keys())
+            print(len(dict_to[k][0].keys()), dict_to[k][0].keys())
+        assert k in dict_to.keys()
+        if isinstance(dict_from[k], dict):
+            copy_state_dict(dict_from[k], dict_to[k])
+        elif isinstance(dict_from[k], list):
+            copy_state_dict_list(dict_from[k], dict_to[k])
+        else:
+            dict_to[k] = dict_from[k]
+    return dict_to
+
+if __name__ == "__main__":
+    os.chdir("..")
+    torch.backends.cudnn.benchmark = True
+    want_just_images = True
+    model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
+    model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
+
+    '''
+    model_to.netG.module.update_for_step(1000000000000)
+    l = torch.nn.MSELoss()
+    o, _ = model_to.netG(torch.randn(1, 3, 64, 64))
+    l(o, torch.randn_like(o)).backward()
+    model_to.optimizer_G.step()
+    o = model_to.netD(torch.randn(1, 3, 128, 128))
+    l(o, torch.randn_like(o)).backward()
+    model_to.optimizer_D.step()
+    '''
+
+    torch.save(copy_state_dict(model_from.netG.state_dict(), model_to.netG.state_dict()), "converted_g.pth")
+    torch.save(copy_state_dict(model_from.netD.state_dict(), model_to.netD.state_dict()), "converted_d.pth")
+
+    # Also convert the state.
+    resume_state_from = torch.load(opt_from['path']['resume_state'])
+    resume_state_to = model_to.save_training_state(0, 0, return_state=True)
+    resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
+    torch.save(resume_state_from, "converted_state.pth")
+
+
+
+
+
+
+