diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 6de9664c..427b027c 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -246,7 +246,6 @@ class SRGANModel(BaseModel):
             else:
                 gen_img = fake_GenOut
                 self.fake_GenOut.append(fake_GenOut.detach())
-            var_ref_skips.append(var_ref[0].detach())
 
             l_g_total = 0
             if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@@ -308,7 +307,7 @@ class SRGANModel(BaseModel):
             noise = torch.randn_like(var_ref[0]) * noise_theta
             noise.to(self.device)
             self.optimizer_D.zero_grad()
-            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
                 # Re-compute generator outputs (post-update).
                 with torch.no_grad():
                     fake_H = self.netG(var_L)
@@ -320,9 +319,8 @@ class SRGANModel(BaseModel):
                     _t = time()
 
                 # Apply noise to the inputs to slow discriminator convergence.
-                var_ref = (var_ref[0] + noise,)
+                var_ref = (var_ref + noise,)
                 fake_H = (fake_H[0] + noise,) + fake_H[1:]
-                self.fake_H.append(fake_H[0].detach())
                 if self.opt['train']['gan_type'] == 'gan':
                     # need to forward and backward separately, since batch norm statistics differ
                     # real
@@ -351,9 +349,9 @@ class SRGANModel(BaseModel):
 
                     if random.random() > .25:
                         # Make the swap across fake_H and var_ref
-                        SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
+                        SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
                         assert SWAP_MAX_DIM > 0
-                        swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION
+                        swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
                         swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
                         t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
                         fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
@@ -406,8 +404,13 @@ class SRGANModel(BaseModel):
                 if _profile:
                     print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
                     _t = time()
+
+                # Append var_ref here, so that we can inspect the alterations the disc made if pixgan
+                var_ref_skips.append(var_ref[0].detach())
+                self.fake_H.append(fake_H[0].detach())
 
             self.optimizer_D.step()
+
             if _profile:
                 print("Disc step %f" % (time() - _t,))
                 _t = time()
@@ -432,8 +435,8 @@ class SRGANModel(BaseModel):
                     utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
                     if multi_gen:
                         utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
-                        utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
                         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+                            utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
                             utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
                     else:
                         utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
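
For context on the patch-swap hunk above: the sketch below is an illustrative reconstruction, not code from this repository. The function name swap_random_patch, the PIXDISC_MAX_REDUCTION value, and the tensor shapes are assumptions, chosen only to show how a randomly placed patch of the reference batch can be exchanged with the generated batch so a pixel-level discriminator sees mixed real/fake regions.

# Illustrative sketch only (not part of the patch above): swap a randomly
# placed patch between a generated batch and a real batch, in place.
# Assumes NCHW tensors of equal shape; PIXDISC_MAX_REDUCTION is hypothetical.
import random
import torch

PIXDISC_MAX_REDUCTION = 8  # assumed total downscaling factor of the pixel discriminator

def swap_random_patch(fake: torch.Tensor, real: torch.Tensor) -> None:
    """Exchange one randomly sized, grid-aligned patch between fake and real, in place."""
    assert fake.shape == real.shape and fake.dim() == 4
    max_dim = fake.shape[2] // (2 * PIXDISC_MAX_REDUCTION)
    assert max_dim > 0
    # Snap the patch corner and size to multiples of the reduction factor.
    x = random.randint(0, max_dim) * PIXDISC_MAX_REDUCTION
    y = random.randint(0, max_dim) * PIXDISC_MAX_REDUCTION
    w = random.randint(1, max_dim) * PIXDISC_MAX_REDUCTION
    h = random.randint(1, max_dim) * PIXDISC_MAX_REDUCTION
    patch = fake[:, :, x:x + w, y:y + h].clone()
    fake[:, :, x:x + w, y:y + h] = real[:, :, x:x + w, y:y + h]
    real[:, :, x:x + w, y:y + h] = patch

# Example: 128x128 images, so patch corners land on an 8-pixel grid.
fake = torch.randn(4, 3, 128, 128)
real = torch.randn(4, 3, 128, 128)
swap_random_patch(fake, real)

Snapping the corner and size to multiples of the reduction factor keeps the swapped region aligned with the pixel discriminator's output grid, which is presumably why the hunk above quantizes swap_x/swap_y/swap_w/swap_h the same way.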