From 4bfbdaf94fc5a4252beff4fcbdfb16fc77a21ff4 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 4 Aug 2020 11:28:52 -0600
Subject: [PATCH] Don't recompute generator outputs for D in standard operation

Should significantly improve training performance with negligible
differences in results.
---
 .../compute_fdpl_perceptual_weights.py     |  2 +-
 codes/models/SRGAN_model.py                | 29 +++++++++----------
 codes/models/archs/ProgressiveSrg_arch.py |  2 +-
 3 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/codes/data_scripts/compute_fdpl_perceptual_weights.py b/codes/data_scripts/compute_fdpl_perceptual_weights.py
index ebffbf13..42d4ce2f 100644
--- a/codes/data_scripts/compute_fdpl_perceptual_weights.py
+++ b/codes/data_scripts/compute_fdpl_perceptual_weights.py
@@ -47,7 +47,7 @@ if __name__ == '__main__':
         im = rgb2ycbcr(train_data['GT'].double())
         im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
                                         size=im.shape[2:],
-                                        mode="bicubic"))
+                                        mode="bicubic", align_corners=False))
         patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
         patches_hr = dct_2d(patches_hr, norm='ortho')
         patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 11625dbf..83c428a5 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -506,22 +506,21 @@ class SRGANModel(BaseModel):
             noise.to(self.device)
             real_disc_images = []
             fake_disc_images = []
-            for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
+            for fake_GenOut, var_LGAN, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_H, self.var_ref, self.pix):
                 if random.random() > self.gan_lq_img_use_prob:
-                    gen_input = var_L
+                    fake_H = fake_GenOut.clone().detach().requires_grad_(False)
                 else:
-                    gen_input = var_LGAN
-                # Re-compute generator outputs (post-update).
-                with torch.no_grad():
-                    if self.spsr_enabled:
-                        _, fake_H, _ = self.netG(gen_input)
-                    else:
-                        _, fake_H = self.netG(gen_input)
-                    fake_H = fake_H.detach()
+                    # Re-compute generator outputs with the GAN inputs.
+                    with torch.no_grad():
+                        if self.spsr_enabled:
+                            _, fake_H, _ = self.netG(var_LGAN)
+                        else:
+                            _, fake_H = self.netG(var_LGAN)
+                        fake_H = fake_H.detach()

-                if _profile:
-                    print("Gen forward for disc %f" % (time() - _t,))
-                    _t = time()
+                    if _profile:
+                        print("Gen forward for disc %f" % (time() - _t,))
+                        _t = time()

                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = var_ref + noise
@@ -583,8 +582,8 @@
                     fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0

                 # Interpolate down to the dimensionality that the discriminator uses.
-                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
-                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
+                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
+                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear", align_corners=False)

                 # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
diff --git a/codes/models/archs/ProgressiveSrg_arch.py b/codes/models/archs/ProgressiveSrg_arch.py
index 66d48586..b8c615e3 100644
--- a/codes/models/archs/ProgressiveSrg_arch.py
+++ b/codes/models/archs/ProgressiveSrg_arch.py
@@ -266,7 +266,7 @@ class GrowingUnetDiscBase(nn.Module):
             disc_age = self.latest_step - self.progressive_schedule[i]
             fade_in = min(1, disc_age * self.growth_fade_in_per_step)
             mean_weight += fade_in
-            base_loss += F.interpolate(l, size=res, mode="bilinear") * fade_in
+            base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in
         base_loss /= mean_weight

         return base_loss.view(-1, 1)
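Note on the core change: the D step now reuses the generator output that was
already computed during the G step (self.fake_GenOut), taking a detached clone
instead of running netG forward a second time. Below is a minimal sketch of
the pattern; netG, netD, and the loss terms are toy stand-ins for illustration,
not this repo's API, and optimizer/requires_grad bookkeeping is omitted:

import torch
import torch.nn as nn

# Toy stand-ins; the repo's generator/discriminator are far larger.
netG = nn.Conv2d(3, 3, 3, padding=1)
netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1))

lq = torch.randn(4, 3, 32, 32)  # low-quality generator input
hq = torch.randn(4, 3, 32, 32)  # ground-truth reference

# G step: a single forward pass; the output is kept for the D step.
fake = netG(lq)
g_loss = -netD(fake).mean()  # placeholder adversarial loss
g_loss.backward()

# D step (the optimization in this patch): reuse the cached output.
# clone().detach() yields a graph-free copy, so no gradient reaches G
# and, crucially, no second netG forward is needed.
fake_for_d = fake.clone().detach()
d_loss = netD(fake_for_d).mean() - netD(hq).mean()  # placeholder critic loss
d_loss.backward()

The else branch above still recomputes, because with probability
gan_lq_img_use_prob the discriminator is fed outputs generated from different
inputs (self.gan_img) than the ones used for the G step, so there is no cached
output to reuse.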
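Note on the align_corners changes: these should be behavior-preserving. Since
PyTorch 0.4.0, align_corners=False has been the default for bilinear and
bicubic interpolation; passing it explicitly just silences the UserWarning
that F.interpolate emits on the 1.x releases this repo targets when the
argument is left unspecified. A quick check of that assumption:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
a = F.interpolate(x, size=(16, 16), mode="bilinear")                        # warns on 1.x
b = F.interpolate(x, size=(16, 16), mode="bilinear", align_corners=False)  # silent
assert torch.equal(a, b)  # same values either way; only the warning differs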