From 06d18343f73f37068463783690f40091a5957547 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 12 May 2020 16:25:38 -0600 Subject: [PATCH] Allow noise to be added to discriminator inputs --- codes/models/SRGAN_model.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index cd5387f1..bb834b68 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -78,6 +78,8 @@ class SRGANModel(BaseModel): # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 + self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0 + self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0 # optimizers # G @@ -165,7 +167,14 @@ class SRGANModel(BaseModel): for p in self.netG.parameters(): p.requires_grad = False + # Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta. + if step >= self.D_noise_final: + noise_theta = 0 + else: + noise_theta = self.D_noise_theta * (self.D_noise_final - step) / self.D_noise_final + self.fake_GenOut = [] + var_ref_skips = [] for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): fake_GenOut = self.netG(var_L) @@ -177,9 +186,11 @@ class SRGANModel(BaseModel): self.fake_GenOut.append((fake_GenOut[0].detach(), fake_GenOut[1].detach(), fake_GenOut[2].detach())) + var_ref = (var_ref,) + self.create_artificial_skips(var_H) else: gen_img = fake_GenOut self.fake_GenOut.append(fake_GenOut.detach()) + var_ref_skips.append(var_ref) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: @@ -219,17 +230,13 @@ class SRGANModel(BaseModel): for p in self.netD.parameters(): p.requires_grad = True - # Convert var_ref to have the same output format as the generator. This generally means interpolating the - # HR images to have the same output dimensions as each generator skip connection. - if isinstance(self.fake_GenOut[0], tuple): - var_ref_skips = [] - for ref, hi_res in zip(self.var_ref, self.var_H): - var_ref_skips.append((ref,) + self.create_artificial_skips(hi_res)) - else: - var_ref_skips = self.var_ref - + noise = torch.randn_like(var_ref[0]) * noise_theta + noise.to(self.device) self.optimizer_D.zero_grad() for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut): + # Apply noise to the inputs to slow discriminator convergence. + var_ref = (var_ref[0] + noise,) + var_ref[1:] + fake_H = (fake_H[0] + noise,) + fake_H[1:] if self.opt['train']['gan_type'] == 'gan': # need to forward and backward separately, since batch norm statistics differ # real @@ -297,6 +304,7 @@ class SRGANModel(BaseModel): self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) + self.add_log_entry('noise_theta', noise_theta) # Allows the log to serve as an easy-to-use rotating buffer. def add_log_entry(self, key, value):