From 66e91a3d9e313036994288ae22c83124d68ac778 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 30 Apr 2020 11:45:07 -0600
Subject: [PATCH] Revert "Enable skip-through connections from disc to gen"

This reverts commit b7857f35c34f9f2f7c176c2a1f3666554b5b13f9.
---
 codes/models/SRGAN_model.py                   | 35 +++++--------------
 codes/models/archs/FlatProcessorNet_arch.py   | 21 +++--------
 codes/models/archs/discriminator_vgg_arch.py  | 18 +---------
 .../train/train_GAN_blacked_corrupt.yml       | 23 ++++++------
 4 files changed, 25 insertions(+), 72 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 6b6c2855..ed9129c1 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -150,16 +150,9 @@ class SRGANModel(BaseModel):
         for p in self.netD.parameters():
             p.requires_grad = False
 
-        disc_passthrough = None
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
-            genOut = self.netG(self.var_L)
-            if type(genOut) is tuple:
-                self.fake_H = genOut[0]
-                disc_passthrough = genOut[1]
-            else:
-                self.fake_H = genOut
-                disc_passthrough = None
+            self.fake_H = self.netG(self.var_L)
         else:
             self.fake_H = self.pix
 
@@ -186,14 +179,12 @@ class SRGANModel(BaseModel):
                 if step % self.l_fea_w_decay_steps == 0:
                     self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-            if disc_passthrough is not None:
-                pred_g_fake = self.netD(self.fake_H, disc_passthrough)
-            else:
-                pred_g_fake = self.netD(self.fake_H)
             if self.opt['train']['gan_type'] == 'gan':
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
                 pred_d_real = self.netD(self.var_ref).detach()
+                pred_g_fake = self.netD(self.fake_H)
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@@ -208,21 +199,15 @@ class SRGANModel(BaseModel):
             p.requires_grad = True
 
         self.optimizer_D.zero_grad()
-        if disc_passthrough is not None:
-            dp = {}
-            for k, v in disc_passthrough.items():
-                dp[k] = v.detach()
-            pred_d_fake = self.netD(self.fake_H.detach(), dp)
-        else:
-            pred_d_fake = self.netD(self.fake_H.detach())
         if self.opt['train']['gan_type'] == 'gan':
             # need to forward and backward separately, since batch norm statistics differ
-            # reald
+            # real
             pred_d_real = self.netD(self.var_ref)
             l_d_real = self.cri_gan(pred_d_real, True)
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
             # fake
+            pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
             l_d_fake = self.cri_gan(pred_d_fake, False)
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
@@ -233,10 +218,12 @@ class SRGANModel(BaseModel):
             # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
             # l_d_total = (l_d_real + l_d_fake) / 2
             # l_d_total.backward()
+            pred_d_fake = self.netD(self.fake_H.detach()).detach()
             pred_d_real = self.netD(self.var_ref)
-            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake.detach()), True) * 0.5
+            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
             with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                 l_d_real_scaled.backward()
+            pred_d_fake = self.netD(self.fake_H.detach())
             l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
@@ -258,11 +245,7 @@ class SRGANModel(BaseModel):
     def test(self):
         self.netG.eval()
         with torch.no_grad():
-            genOut = self.netG(self.var_L)
-            if type(genOut) is tuple:
-                self.fake_H = genOut[0]
-            else:
-                self.fake_H = genOut
+            self.fake_H = self.netG(self.var_L)
         self.netG.train()
 
     def get_current_log(self):
diff --git a/codes/models/archs/FlatProcessorNet_arch.py b/codes/models/archs/FlatProcessorNet_arch.py
index 374b221b..2ce1b978 100644
--- a/codes/models/archs/FlatProcessorNet_arch.py
+++ b/codes/models/archs/FlatProcessorNet_arch.py
@@ -23,12 +23,10 @@ class ReduceAnnealer(nn.Module):
         self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         arch_util.initialize_weights([self.reducer, self.annealer], .1)
-        self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
-        self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, x, interpolated_trunk):
-        out = self.lrelu(self.bn_reduce(self.reducer(x)))
-        out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
+        out = self.lrelu(self.reducer(x))
+        out = self.lrelu(self.res_trunk(out))
         annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
         return annealed, out
 
@@ -43,13 +41,11 @@ class Assembler(nn.Module):
         self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
         self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-        self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
-        self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, input, skip_raw):
         out = self.pixel_shuffle(input)
-        out = self.bn_up(self.upsampler(out)) + skip_raw
-        out = self.lrelu(self.bn(self.res_trunk(out)))
+        out = self.upsampler(out) + skip_raw
+        out = self.lrelu(self.res_trunk(out))
         return out
 
 class FlatProcessorNet(nn.Module):
@@ -84,15 +80,10 @@ class FlatProcessorNet(nn.Module):
 
         # Produce assemblers for all possible downscale variants. Some may not be used.
         self.assembler1 = Assembler(nf, assembler_blocks)
-        self.assemble1_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler2 = Assembler(nf, assembler_blocks)
-        self.assemble2_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler3 = Assembler(nf, assembler_blocks)
-        self.assemble3_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assembler4 = Assembler(nf, assembler_blocks)
-        self.assemble4_conv = nn.Conv2d(nf*4, 3, 3, stride=1, padding=1, bias=True)
         self.assemblers = [self.assembler1, self.assembler2, self.assembler3, self.assembler4]
-        self.assemble_convs = [self.assemble1_conv, self.assemble2_conv, self.assemble3_conv, self.assemble4_conv]
 
         # Initialization
         arch_util.initialize_weights([self.conv_first, self.conv_last], .1)
@@ -113,10 +104,8 @@ class FlatProcessorNet(nn.Module):
             raw_values.append(raw)
 
         i = -1
-        scaled_outputs = {}
         out = raw_values[-1]
         while downsamples != self.downscale:
-            scaled_outputs[int(x.shape[-1] / downsamples)] = self.assemble_convs[i](out)
             out = self.assemblers[i](out, raw_values[i-1])
             i -= 1
             downsamples = int(downsamples / 2)
@@ -126,4 +115,4 @@ class FlatProcessorNet(nn.Module):
         basis = x
         if downsamples != 1:
             basis = F.interpolate(x, scale_factor=1/downsamples, mode='bilinear', align_corners=False)
-        return basis + out, scaled_outputs
+        return basis + out
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 42db4f6b..10a3ccdc 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torchvision
-import torch.nn.functional as F
 
 
 class Discriminator_VGG_128(nn.Module):
@@ -12,17 +11,11 @@ class Discriminator_VGG_128(nn.Module):
         self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
         self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
         self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
-
-        self.skipconv0 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
-
         # [64, 64, 64]
         self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
         self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
         self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
         self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
-
-        self.skipconv1 = nn.Conv2d(3, nf*2, 3, 1, 1, bias=True)
-
         # [128, 32, 32]
         self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
         self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
@@ -45,22 +38,13 @@ class Discriminator_VGG_128(nn.Module):
         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
-    def forward(self, x, gen_skips=None):
-        x_dim = x.size(-1)
-        if gen_skips is None:
-            gen_skips = {
-                int(x_dim/2): F.interpolate(x, scale_factor=1/2, mode='bilinear', align_corners=False),
-                int(x_dim/4): F.interpolate(x, scale_factor=1/4, mode='bilinear', align_corners=False),
-            }
-
+    def forward(self, x):
         fea = self.lrelu(self.conv0_0(x))
         fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
 
-        fea = (fea + self.skipconv0(gen_skips[x_dim/2])) / 2
         fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
         fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
 
-        fea = (fea + self.skipconv1(gen_skips[x_dim/4])) / 2
         fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
         fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
 
diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml
index c884ab70..f2f0edd6 100644
--- a/codes/options/train/train_GAN_blacked_corrupt.yml
+++ b/codes/options/train/train_GAN_blacked_corrupt.yml
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
     mismatched_Data_OK: true
     use_shuffle: true
-    n_workers: 8 # per GPU
-    batch_size: 32
+    n_workers: 0 # per GPU
+    batch_size: 16
     target_size: 64
     use_flip: false
     use_rot: false
@@ -34,17 +34,14 @@ network_G:
   which_model_G: FlatProcessorNet
   in_nc: 3
   out_nc: 3
-  nf: 32
-  ra_blocks: 6
-  assembler_blocks: 4
+  nf: 48
+  ra_blocks: 4
+  assembler_blocks: 3
 
 network_D:
-  which_model_D: discriminator_vgg_128
+  which_model_D: discriminator_resnet
   in_nc: 3
   nf: 64
-  #which_model_D: discriminator_resnet
-  #in_nc: 3
-  #nf: 32
 
 #### path
 path:
@@ -59,7 +56,7 @@ train:
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 2e-4
+  lr_D: !!float 1e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -74,11 +71,11 @@ train:
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
-  gan_type: gan  # gan | ragan
+  gan_type: ragan  # gan | ragan
   gan_weight: !!float 1e-1
 
-  D_update_ratio: 1
-  D_init_iters: -1
+  D_update_ratio: 2
+  D_init_iters: 1200
 
   manual_seed: 10
   val_freq: !!float 5e2