From a38dd62489a527ff0eebcdc7dda6adf97fd60d1a Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Mon, 1 Jun 2020 15:09:10 -0600
Subject: [PATCH] Only train the discriminator and compute GAN losses when gan_w > 0
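
Previously the discriminator was updated (and netD run on every
generator output) even when the GAN loss weight was zero. Guard both
sides behind self.l_gan_w > 0: the generator only adds its gan/ragan
adversarial term when the weight is positive, and the entire D pass
(re-enabling grads on netD, noise injection, the real/fake forward and
backward passes, and optimizer_D.step()) is skipped otherwise.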

---
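Note (kept below the "---" cut line, so `git am` drops it): a minimal,
self-contained sketch of the gating pattern this patch applies, not
code from this repo. The toy modules, the plain BCE criterion with
explicit targets, and the bare l_gan_w variable are illustrative
stand-ins for SRGANModel's netG/netD, its GANLoss-style cri_gan, and
self.l_gan_w; AMP loss scaling and the mega-batch microloop are
omitted.

    import torch
    import torch.nn as nn

    # Stand-ins for netG/netD; shapes are illustrative only.
    netG = nn.Linear(8, 8)
    netD = nn.Linear(8, 1)
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=1e-4)
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-4)
    cri_pix = nn.L1Loss()
    bce = nn.BCEWithLogitsLoss()  # stand-in for self.cri_gan
    l_gan_w = 0.0                 # gan_w == 0: adversarial training off

    var_L, var_ref = torch.randn(4, 8), torch.randn(4, 8)

    # G step: the adversarial term only enters the total when the
    # weight is positive, so netD is never even run when l_gan_w == 0.
    for p in netD.parameters():
        p.requires_grad = False
    fake = netG(var_L)
    l_g_total = cri_pix(fake, var_ref)
    if l_gan_w > 0:
        pred_g_fake = netD(fake)
        l_g_total = l_g_total + l_gan_w * bce(
            pred_g_fake, torch.ones_like(pred_g_fake))
    optimizer_G.zero_grad()
    l_g_total.backward()
    optimizer_G.step()

    # D step: the whole update is skipped when the weight is zero,
    # which is what this patch adds; before it, D trained regardless.
    if l_gan_w > 0:
        for p in netD.parameters():
            p.requires_grad = True
        optimizer_D.zero_grad()
        pred_d_real = netD(var_ref)
        pred_d_fake = netD(netG(var_L).detach())
        l_d = (bce(pred_d_real, torch.ones_like(pred_d_real)) +
               bce(pred_d_fake, torch.zeros_like(pred_d_fake)))
        l_d.backward()
        optimizer_D.step()

With l_gan_w = 0.0 as above, neither guarded block runs and netD is
never touched; any positive weight restores the usual alternating G/D
updates.
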
 codes/models/SRGAN_model.py | 108 ++++++++++++++++++------------------
 1 file changed, 55 insertions(+), 53 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 66fc05a6..19665e67 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -239,16 +239,17 @@ class SRGANModel(BaseModel):
                     if step % self.l_fea_w_decay_steps == 0:
                         self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-                if self.opt['train']['gan_type'] == 'gan':
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
-                elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_d_real = self.netD(var_ref).detach()
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * (
-                        self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
-                        self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
-                l_g_total += l_g_gan
+                if self.l_gan_w > 0:
+                    if self.opt['train']['gan_type'] == 'gan':
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+                    elif self.opt['train']['gan_type'] == 'ragan':
+                        pred_d_real = self.netD(var_ref).detach()
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * (
+                            self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                            self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+                    l_g_total += l_g_gan
 
                 # Scale the loss down by the batch factor.
                 l_g_total = l_g_total / self.mega_batch_factor
@@ -258,51 +259,52 @@ class SRGANModel(BaseModel):
         self.optimizer_G.step()
 
         # D
-        for p in self.netD.parameters():
-            p.requires_grad = True
+        if self.l_gan_w > 0:
+            for p in self.netD.parameters():
+                p.requires_grad = True
 
-        noise = torch.randn_like(var_ref[0]) * noise_theta
-        noise.to(self.device)
-        self.optimizer_D.zero_grad()
-        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
-            # Re-compute generator outputs (post-update).
-            with torch.no_grad():
-                fake_H = self.netG(var_L)
-                # The following line detaches all generator outputs that are not None.
-                fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
+            noise = torch.randn_like(var_ref[0]) * noise_theta
+            noise = noise.to(self.device)  # Tensor.to() is not in-place; keep the result.
+            self.optimizer_D.zero_grad()
+            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+                # Re-compute generator outputs (post-update).
+                with torch.no_grad():
+                    fake_H = self.netG(var_L)
+                    # The following line detaches all generator outputs that are not None.
+                    fake_H = tuple(x.detach() if x is not None else None for x in fake_H)
 
-            # Apply noise to the inputs to slow discriminator convergence.
-            var_ref = (var_ref[0] + noise,) + var_ref[1:]
-            fake_H = (fake_H[0] + noise,) + fake_H[1:]
-            if self.opt['train']['gan_type'] == 'gan':
-                # need to forward and backward separately, since batch norm statistics differ
-                # real
-                pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
-                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                    l_d_real_scaled.backward()
-                # fake
-                pred_d_fake = self.netD(fake_H)
-                l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
-                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                    l_d_fake_scaled.backward()
-            elif self.opt['train']['gan_type'] == 'ragan':
-                # pred_d_real = self.netD(var_ref)
-                # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
-                # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-                # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-                # l_d_total = (l_d_real + l_d_fake) / 2
-                # l_d_total.backward()
-                pred_d_fake = self.netD(fake_H).detach()
-                pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
-                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                    l_d_real_scaled.backward()
-                pred_d_fake = self.netD(fake_H)
-                l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
-                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                    l_d_fake_scaled.backward()
-        self.optimizer_D.step()
+                # Apply noise to the inputs to slow discriminator convergence.
+                var_ref = (var_ref[0] + noise,) + var_ref[1:]
+                fake_H = (fake_H[0] + noise,) + fake_H[1:]
+                if self.opt['train']['gan_type'] == 'gan':
+                    # need to forward and backward separately, since batch norm statistics differ
+                    # real
+                    pred_d_real = self.netD(var_ref)
+                    l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
+                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                        l_d_real_scaled.backward()
+                    # fake
+                    pred_d_fake = self.netD(fake_H)
+                    l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
+                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                        l_d_fake_scaled.backward()
+                elif self.opt['train']['gan_type'] == 'ragan':
+                    # pred_d_real = self.netD(var_ref)
+                    # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
+                    # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+                    # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+                    # l_d_total = (l_d_real + l_d_fake) / 2
+                    # l_d_total.backward()
+                    pred_d_fake = self.netD(fake_H).detach()
+                    pred_d_real = self.netD(var_ref)
+                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
+                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                        l_d_real_scaled.backward()
+                    pred_d_fake = self.netD(fake_H)
+                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
+                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                        l_d_fake_scaled.backward()
+            self.optimizer_D.step()
 
         # Log sample images from first microbatch.
         if step % 50 == 0: