From e992e18767beb9deda4d49f36d08de64c5f54644 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Fri, 1 Jan 2021 11:59:36 -0700
Subject: [PATCH] Add initial_stride term to style_sr

Also fix fid and a networks.py issue.
---
 codes/models/improve_rrdb/styled_sr.py | 21 +++++++--------
 codes/trainer/eval/sr_fid.py           | 37 ++++++++++++++++++++++++--
 codes/trainer/networks.py              |  8 +++---
 3 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/codes/models/improve_rrdb/styled_sr.py b/codes/models/improve_rrdb/styled_sr.py
index e970dec3..dea13bc2 100644
--- a/codes/models/improve_rrdb/styled_sr.py
+++ b/codes/models/improve_rrdb/styled_sr.py
@@ -9,7 +9,7 @@ from models.RRDBNet_arch import RRDB
 from models.arch_util import ConvGnLelu, default_init_weights
 from models.stylegan.stylegan2_lucidrains import StyleVectorizer, GeneratorBlock, Conv2DMod, leaky_relu, Blur
 from trainer.networks import register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, opt_get
 
 
 class EncoderRRDB(nn.Module):
@@ -35,10 +35,10 @@ class EncoderRRDB(nn.Module):
 
 
 class StyledSrEncoder(nn.Module):
-    def __init__(self, fea_out=256):
+    def __init__(self, fea_out=256, initial_stride=1):
         super().__init__()
         # Current assumes fea_out=256.
-        self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, norm=False, activation=False, bias=True)
+        self.initial_conv = ConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True)
         self.rrdbs = nn.ModuleList([
            EncoderRRDB(32),
            EncoderRRDB(64),
@@ -56,7 +56,7 @@ class StyledSrEncoder(nn.Module):
 
 
 class Generator(nn.Module):
-    def __init__(self, image_size, latent_dim, transparent=False, start_level=3, upsample_levels=2):
+    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2):
         super().__init__()
         total_levels = upsample_levels + 1  # The first level handles the raw encoder output and doesn't upsample.
         self.image_size = image_size
@@ -75,7 +75,7 @@ class Generator(nn.Module):
             8,    # 1024x1024
         ]
 
-        self.encoder = StyledSrEncoder(filters[start_level])
+        self.encoder = StyledSrEncoder(filters[start_level], initial_stride)
 
         in_out_pairs = list(zip(filters[:-1], filters[1:]))
         self.blocks = nn.ModuleList([])
@@ -88,8 +88,7 @@ class Generator(nn.Module):
                 in_chan,
                 out_chan,
                 upsample=not_first,
-                upsample_rgb=not_last,
-                rgba=transparent
+                upsample_rgb=not_last
             )
             self.blocks.append(block)
 
@@ -108,10 +107,10 @@ class Generator(nn.Module):
 
 
 class StyledSrGenerator(nn.Module):
-    def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1):
+    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1):
         super().__init__()
         self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
-        self.gen = Generator(image_size=image_size, latent_dim=latent_dim)
+        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride)
         self.mixed_prob = .9
         self._init_weights()
 
@@ -160,5 +159,5 @@ if __name__ == '__main__':
 
 
 @register_model
-def register_opt_styled_sr(opt_net, opt):
-    return StyledSrGenerator(128)
+def register_styled_sr(opt_net, opt):
+    return StyledSrGenerator(128, initial_stride=opt_get(opt_net, ['initial_stride'], 1))
diff --git a/codes/trainer/eval/sr_fid.py b/codes/trainer/eval/sr_fid.py
index d20393ad..b5e6e097 100644
--- a/codes/trainer/eval/sr_fid.py
+++ b/codes/trainer/eval/sr_fid.py
@@ -12,7 +12,8 @@ from data import create_dataset
 from torch.utils.data import DataLoader
 
 
-# Computes the SR FID score for a network.
+# Computes the SR FID score for a network, which is a FID score that attempts to account for structural changes the
+# generator might make from the source image.
 class SrFidEvaluator(evaluator.Evaluator):
     def __init__(self, model, opt_eval, env):
         super().__init__(model, opt_eval, env)
@@ -26,7 +27,7 @@ class SrFidEvaluator(evaluator.Evaluator):
         self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
 
     def perform_eval(self):
-        fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
+        fid_fake_path = osp.join(self.env['base_path'], "..", "sr_fid", str(self.env["step"]))
         os.makedirs(fid_fake_path, exist_ok=True)
         counter = 0
         for batch in tqdm(self.dataloader):
@@ -49,3 +50,35 @@ class SrFidEvaluator(evaluator.Evaluator):
 
         return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
                                                            2048)}
+
+
+# A "normal" FID computation from a generator that takes LR inputs. Does not account for structural differences at all.
+class FidForStructuralNetsEvaluator(evaluator.Evaluator):
+    def __init__(self, model, opt_eval, env):
+        super().__init__(model, opt_eval, env)
+        self.batch_sz = opt_eval['batch_size']
+        assert self.batch_sz is not None
+        self.dataset = create_dataset(opt_eval['dataset'])
+        self.scale = opt_eval['scale']
+        self.fid_real_samples = opt_eval['dataset']['paths']  # This is assumed to exist for the given dataset.
+        assert isinstance(self.fid_real_samples, str)
+        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1)
+        self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
+
+    def perform_eval(self):
+        fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
+        os.makedirs(fid_fake_path, exist_ok=True)
+        counter = 0
+        for batch in tqdm(self.dataloader):
+            lq = batch['lq'].to(self.env['device'])
+            gen = self.model(lq)
+            if not isinstance(gen, list) and not isinstance(gen, tuple):
+                gen = [gen]
+            gen = gen[self.gen_output_index]
+
+            for b in range(self.batch_sz):
+                torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
+                counter += 1
+
+        return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
+                                                           2048)}
\ No newline at end of file
diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py
index 422d6f91..3a192f6f 100644
--- a/codes/trainer/networks.py
+++ b/codes/trainer/networks.py
@@ -129,10 +129,10 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
         netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
     elif which_model == "stylegan2_discriminator":
         attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
-        disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
-        netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
-    elif which_model == "rrdb_disc":
-        netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
+        from models.stylegan.stylegan2_lucidrains import StyleGan2Discriminator
+        disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
+        from models.stylegan.stylegan2_lucidrains import StyleGan2Augmentor
+        netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD