From b10bcf64368d04dd33c53ad1788c8fdbfa76a4bd Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Mon, 23 Nov 2020 11:31:11 -0700
Subject: [PATCH] Rework stylegan_for_sr to incorporate structure as an AdaIN
 block
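
Structure was previously threaded through the standard GeneratorBlock. It is
now injected by a dedicated GeneratorBlockWithStructure, which resizes the
structural latent to the current feature resolution and applies it through a
StyleGAN1-style AdaIN stage (EqualConv2d + NoiseInjection +
AdaptiveInstanceNorm) ahead of the usual modulated convolutions. Supporting
changes:

- Add RRDBLatentWrapper, which wraps an (optionally pretrained) srflow
  RRDBNet and postprocesses selected block outputs into a structural latent;
  exposed in networks.py as 'rrdb_latent_wrapper'.
- Generator.forward accepts a starting_shape so the initial feature block can
  be interpolated to track the structure input's spatial shape.
- SrStyleTransferEvaluator routes the downsampled batch through a
  configurable embedding generator before invoking the model.
- ImageGeneratorInjector gains a 'grad' option for running a generator under
  torch.no_grad().
- extract_square_images skips unreadable images and uses more worker threads;
  train.py/train2.py default option files updated.

For reference, the new define_G branch consumes options shaped roughly as
below (the values here are illustrative only, not part of this patch):

    opt_net = {
        'which_model_G': 'rrdb_latent_wrapper',
        'in_nc': 3, 'out_nc': 3,              # illustrative channel counts
        'nf': 64, 'nb': 23,                   # illustrative RRDB capacity
        'with_bypass': True,
        'blocks_for_latent': [1, 8, 15, 22],  # illustrative tap points
        'scale': 4,
        'pretrain_path': None,                # or a path to RRDB weights
    }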

---
 codes/models/archs/RRDBNet_arch.py            |   1 -
 .../models/archs/srflow_orig/RRDBNet_arch.py  |  26 ++-
 codes/models/archs/stylegan/stylegan2.py      | 194 +++++++++++++++---
 codes/models/eval/sr_style.py                 |   5 +-
 codes/models/networks.py                      |   5 +
 codes/models/steps/injectors.py               |   7 +-
 codes/scripts/extract_square_images.py        |   5 +-
 codes/train.py                                |   6 +-
 codes/train2.py                               |   2 +-
 9 files changed, 209 insertions(+), 42 deletions(-)

diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index a7fec89d..cddd1c4f 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -222,4 +222,3 @@ class RRDBNet(nn.Module):
         for i, bm in enumerate(self.body):
             if hasattr(bm, 'bypass_map'):
                 torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
-
diff --git a/codes/models/archs/srflow_orig/RRDBNet_arch.py b/codes/models/archs/srflow_orig/RRDBNet_arch.py
index 607f4e0e..3828d7b8 100644
--- a/codes/models/archs/srflow_orig/RRDBNet_arch.py
+++ b/codes/models/archs/srflow_orig/RRDBNet_arch.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import models.archs.srflow_orig.module_util as mutil
-from models.archs.arch_util import default_init_weights, ConvGnSilu
+from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
 from utils.util import opt_get
 
 
@@ -231,3 +231,27 @@ class RRDBNet(nn.Module):
             return results
         else:
             return out
+
+
+class RRDBLatentWrapper(nn.Module):
+    def __init__(self, in_nc, out_nc, nf, nb, with_bypass, blocks, pretrain_rrdb_path=None, gc=32, scale=4):
+        super().__init__()
+        self.with_bypass = with_bypass
+        self.blocks = blocks
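+        # Mimic the option structure RRDBNet expects so the srflow trunk can
+        # be reused unmodified while exposing its intermediate block outputs.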
+        fake_opt = {'networks': {'generator': {'flow': {'stackRRDB': {'blocks': blocks}}, 'rrdb_bypass': with_bypass}}}
+        self.wrappedRRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, fake_opt)
+        if pretrain_rrdb_path is not None:
+            rrdb_state_dict = torch.load(pretrain_rrdb_path)
+            self.wrappedRRDB.load_state_dict(rrdb_state_dict, strict=True)
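+        # One nf-channel feature map per tapped block, plus the trunk's last_lr_fea.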
+        out_dim = nf * (len(blocks) + 1)
+        self.postprocess = nn.Sequential(ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
+                                         ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
+                                         ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=False, norm=False))
+
+    def forward(self, lr):
+        rrdbResults = self.wrappedRRDB(lr, get_steps=True)
+        blocklist = [rrdbResults["block_{}".format(idx)] for idx in self.blocks]
+        blocklist.append(rrdbResults['last_lr_fea'])
+        fea = torch.cat(blocklist, dim=1)
+        fea = self.postprocess(fea)
+        return fea
\ No newline at end of file
diff --git a/codes/models/archs/stylegan/stylegan2.py b/codes/models/archs/stylegan/stylegan2.py
index 16735af3..8cd9365c 100644
--- a/codes/models/archs/stylegan/stylegan2.py
+++ b/codes/models/archs/stylegan/stylegan2.py
@@ -1,3 +1,4 @@
+import functools
 import math
 import multiprocessing
 from contextlib import contextmanager, ExitStack
@@ -371,6 +372,76 @@ class RGBBlock(nn.Module):
         return x
 
 
+class AdaptiveInstanceNorm(nn.Module):
+    def __init__(self, in_channel, style_dim):
+        super().__init__()
+        from models.archs.arch_util import ConvGnLelu
+        self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
+        self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
+        self.norm = nn.InstanceNorm2d(in_channel)
+
+    def forward(self, input, style):
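+        # 'style' is a spatial (b,c,h,w) embedding; the 1x1 convs turn it into
+        # a per-pixel scale and shift for the instance-normalized input.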
+        gamma = self.style2scale(style)
+        beta = self.style2bias(style)
+        out = self.norm(input)
+        out = gamma * out + beta
+        return out
+
+
+class NoiseInjection(nn.Module):
+    def __init__(self, channel):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
+
+    def forward(self, image, noise):
+        return image + self.weight * noise
+
+
+class EqualLR:
+    def __init__(self, name):
+        self.name = name
+
+    def compute_weight(self, module):
+        weight = getattr(module, self.name + '_orig')
+        fan_in = weight.data.size(1) * weight.data[0][0].numel()
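+        # Rescale by sqrt(2 / fan_in) on every forward pass: the equalized
+        # learning rate trick from the ProGAN/StyleGAN papers.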
+
+        return weight * math.sqrt(2 / fan_in)
+
+    @staticmethod
+    def apply(module, name):
+        fn = EqualLR(name)
+
+        weight = getattr(module, name)
+        del module._parameters[name]
+        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
+        module.register_forward_pre_hook(fn)
+
+        return fn
+
+    def __call__(self, module, input):
+        weight = self.compute_weight(module)
+        setattr(module, self.name, weight)
+
+
+def equal_lr(module, name='weight'):
+    EqualLR.apply(module, name)
+
+    return module
+
+
+class EqualConv2d(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+        conv = nn.Conv2d(*args, **kwargs)
+        conv.weight.data.normal_()
+        conv.bias.data.zero_()
+        self.conv = equal_lr(conv)
+
+    def forward(self, input):
+        return self.conv(input)
+
+
 class Conv2DMod(nn.Module):
     def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
         super().__init__()
@@ -408,6 +479,54 @@ class Conv2DMod(nn.Module):
         return x
 
 
+class GeneratorBlockWithStructure(nn.Module):
+    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
+        super().__init__()
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
+
+        # Uses StyleGAN1-style blocks to inject the structural latent.
+        self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
+        self.to_noise0 = nn.Linear(1, filters)
+        self.noise0 = equal_lr(NoiseInjection(filters))
+        self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
+
+        self.to_style1 = nn.Linear(latent_dim, filters)
+        self.to_noise1 = nn.Linear(1, filters)
+        self.conv1 = Conv2DMod(filters, filters, 3)
+
+        self.to_style2 = nn.Linear(latent_dim, filters)
+        self.to_noise2 = nn.Linear(1, filters)
+        self.conv2 = Conv2DMod(filters, filters, 3)
+
+        self.activation = leaky_relu()
+        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
+
+    def forward(self, x, prev_rgb, istyle, inoise, structure_input):
+        if exists(self.upsample):
+            x = self.upsample(x)
+
+        inoise = inoise[:, :x.shape[2], :x.shape[3], :]
+        noise0 = self.to_noise0(inoise).permute((0, 3, 1, 2))
+        noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
+        noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
+
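+        # Resize the structural latent to the current feature resolution and
+        # inject it via AdaIN after the first conv and noise injection.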
+        structure = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
+        x = self.conv0(x)
+        x = self.noise0(x, noise0)
+        x = self.adain0(x, structure)
+
+        style1 = self.to_style1(istyle)
+        x = self.conv1(x, style1)
+        x = self.activation(x + noise1)
+
+        style2 = self.to_style2(istyle)
+        x = self.conv2(x, style2)
+        x = self.activation(x + noise2)
+
+        rgb = self.to_rgb(x, prev_rgb, istyle)
+        return x, rgb
+
+
 class GeneratorBlock(nn.Module):
     def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
         super().__init__()
@@ -453,32 +572,6 @@ class GeneratorBlock(nn.Module):
         return x, rgb
 
 
-class DiscriminatorBlock(nn.Module):
-    def __init__(self, input_channels, filters, downsample=True):
-        super().__init__()
-        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
-
-        self.net = nn.Sequential(
-            nn.Conv2d(input_channels, filters, 3, padding=1),
-            leaky_relu(),
-            nn.Conv2d(filters, filters, 3, padding=1),
-            leaky_relu()
-        )
-
-        self.downsample = nn.Sequential(
-            Blur(),
-            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
-        ) if downsample else None
-
-    def forward(self, x):
-        res = self.conv_res(x)
-        x = self.net(x)
-        if exists(self.downsample):
-            x = self.downsample(x)
-        x = (x + res) * (1 / math.sqrt(2))
-        return x
-
-
 class Generator(nn.Module):
     def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
                  fmap_max=512, structure_input=False):
@@ -515,18 +608,22 @@ class Generator(nn.Module):
 
             self.attns.append(attn_fn)
 
-            block = GeneratorBlock(
+            if structure_input:
+                block_fn = GeneratorBlockWithStructure
+            else:
+                block_fn = GeneratorBlock
+
+            block = block_fn(
                 latent_dim,
                 in_chan,
                 out_chan,
                 upsample=not_first,
                 upsample_rgb=not_last,
-                rgba=transparent,
-                structure_input=structure_input
+                rgba=transparent
             )
             self.blocks.append(block)
 
-    def forward(self, styles, input_noise, structure_input=None):
+    def forward(self, styles, input_noise, structure_input=None, starting_shape=None):
         batch_size = styles.shape[0]
         image_size = self.image_size
 
@@ -535,6 +632,8 @@ class Generator(nn.Module):
             x = self.to_initial_block(avg_style)
         else:
             x = self.initial_block.expand(batch_size, -1, -1, -1)
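+        # Optionally resize the initial features so the output can follow a
+        # caller-specified (e.g. structure-derived) spatial shape.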
+        if starting_shape is not None:
+            x = F.interpolate(x, size=starting_shape, mode="bilinear", align_corners=False)
 
         rgb = None
         styles = styles.transpose(0, 1)
@@ -591,7 +690,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
 
     # To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
     # b,f,h,w.
-    def forward(self, x, structure_input=None):
+    def forward(self, x, structure_input=None, fit_starting_shape_to_structure=False):
         b, f, h, w = x.shape
 
         full_random_latents = True
@@ -614,12 +713,15 @@ class StyleGan2GeneratorWithLatent(nn.Module):
             w_space = self.latent_to_w(self.vectorizer, style)
             w_styles = self.styles_def_to_tensor(w_space)
 
+        starting_shape = None
+        if fit_starting_shape_to_structure:
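+            # Seed the generator at 1/32nd of the noise resolution; this
+            # assumes a block depth whose cumulative upsampling factor is 32x.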
+            starting_shape = (x.shape[2] // 32, x.shape[3] // 32)
         # The underlying model expects the noise as b,h,w,1. Make it so.
-        return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles
+        return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input, starting_shape), w_styles
 
     def _init_weights(self):
         for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear}:
+            if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
 
         for block in self.gen.blocks:
@@ -629,6 +731,32 @@ class StyleGan2GeneratorWithLatent(nn.Module):
             nn.init.zeros_(block.to_noise2.bias)
 
 
+class DiscriminatorBlock(nn.Module):
+    def __init__(self, input_channels, filters, downsample=True):
+        super().__init__()
+        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
+
+        self.net = nn.Sequential(
+            nn.Conv2d(input_channels, filters, 3, padding=1),
+            leaky_relu(),
+            nn.Conv2d(filters, filters, 3, padding=1),
+            leaky_relu()
+        )
+
+        self.downsample = nn.Sequential(
+            Blur(),
+            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
+        ) if downsample else None
+
+    def forward(self, x):
+        res = self.conv_res(x)
+        x = self.net(x)
+        if exists(self.downsample):
+            x = self.downsample(x)
+        x = (x + res) * (1 / math.sqrt(2))
+        return x
+
+
 class StyleGan2Discriminator(nn.Module):
     def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
                  transparent=False, fmap_max=512, input_filters=3):
diff --git a/codes/models/eval/sr_style.py b/codes/models/eval/sr_style.py
index ca46fae8..45ca70ee 100644
--- a/codes/models/eval/sr_style.py
+++ b/codes/models/eval/sr_style.py
@@ -22,6 +22,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
         self.im_sz = opt_eval['image_size']
         self.scale = opt_eval['scale']
         self.fid_real_samples = opt_eval['real_fid_path']
+        self.embedding_generator = opt_eval['embedding_generator']
         self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
         self.dataset = Stylegan2Dataset({'path': self.fid_real_samples,
                                          'target_size': self.im_sz,
@@ -30,6 +31,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
         self.sampler = BatchSampler(self.dataset, self.batch_sz, False)
 
     def perform_eval(self):
+        embedding_generator = self.env['generators'][self.embedding_generator]
         fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
         os.makedirs(fid_fake_path, exist_ok=True)
         fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
@@ -40,7 +42,8 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
             batch_hq = [e['GT'] for e in batch]
             batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
             resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
-            gen = self.model(noise, resized_batch)
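+            # The model now receives an embedding of the LR batch as its
+            # structure input instead of the raw downsampled pixels.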
+            embedding = embedding_generator(resized_batch)
+            gen = self.model(noise, embedding)
             if not isinstance(gen, list) and not isinstance(gen, tuple):
                 gen = [gen]
             gen = gen[self.gen_output_index]
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 75130989..41aa7ab2 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -148,6 +148,11 @@ def define_G(opt, opt_net, scale=None):
         from models.archs.srflow_orig import SRFlowNet_arch
         netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
                                      K=opt_net['K'], opt=opt)
+    elif which_model == 'rrdb_latent_wrapper':
+        from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
+        netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+                                 nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
+                                 blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
     return netG
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index 3cb9366a..06165aba 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -78,15 +78,20 @@ class Injector(torch.nn.Module):
 class ImageGeneratorInjector(Injector):
     def __init__(self, opt, env):
         super(ImageGeneratorInjector, self).__init__(opt, env)
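+        # 'grad' (default True) allows a config to run this generator under
+        # torch.no_grad(), e.g. for a frozen embedding network.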
+        self.grad = opt['grad'] if 'grad' in opt.keys() else True
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         with autocast(enabled=self.env['opt']['fp16']):
             if isinstance(self.input, list):
                 params = extract_params_from_state(self.input, state)
+            else:
+                params = [state[self.input]]
+            if self.grad:
                 results = gen(*params)
             else:
-                results = gen(state[self.input])
+                with torch.no_grad():
+                    results = gen(*params)
         new_state = {}
         if isinstance(self.output, list):
             # Only dereference tuples or lists, not tensors.
diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py
index 1a32d912..f037733e 100644
--- a/codes/scripts/extract_square_images.py
+++ b/codes/scripts/extract_square_images.py
@@ -13,7 +13,7 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 5
+    opt['n_thread'] = 20
     opt['compression_level'] = 90  # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
@@ -46,6 +46,9 @@ class TiledDataset(data.Dataset):
         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
 
         # Greyscale not supported.
+        if img is None:
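+            # cv2.imread returns None for unreadable or corrupt files.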
+            print("Error with ", path)
+            return None
         if len(img.shape) == 2:
             return None
         h, w, c = img.shape
diff --git a/codes/train.py b/codes/train.py
index bb0be9cb..38c1bf2b 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -31,8 +31,8 @@ class Trainer:
 
     def init(self, opt, launcher, all_networks={}):
         self._profile = False
-        self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
-        self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
+        self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True
+        self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True
 
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/train2.py b/codes/train2.py
index be751986..ecef64c3 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()