From d09ed4e5f7b7f9c8a976808974a945ad3b67c711 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 26 Jul 2020 22:44:24 -0600 Subject: [PATCH] Misc fixes --- codes/models/SRGAN_model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 677160d1..67eea696 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -102,7 +102,7 @@ class SRGANModel(BaseModel): # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 - self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else 0 + self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else -1 self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0 self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0 self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0 @@ -200,6 +200,8 @@ class SRGANModel(BaseModel): # GAN LQ image params self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0 + self.img_debug_steps = train_opt['img_debug_steps'] if train_opt['img_debug_steps'] else 50 + self.print_network() # print network self.load() # load G and D if needed self.load_random_corruptor() @@ -356,7 +358,7 @@ class SRGANModel(BaseModel): _t = time() # D - if self.l_gan_w > 0 and step > self.G_warmup: + if self.l_gan_w > 0 and step >= self.G_warmup: for p in self.netD.parameters(): if p.dtype != torch.int64 and p.dtype != torch.bool: p.requires_grad = True @@ -506,7 +508,7 @@ class SRGANModel(BaseModel): _t = time() # Log sample images from first microbatch. - if step % 50 == 0: + if step % self.img_debug_steps == 0: sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) @@ -524,7 +526,7 @@ class SRGANModel(BaseModel): utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i))) - if self.l_gan_w > 0 and step > self.G_warmup and 'pixgan' in self.opt['train']['gan_type']: + if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']: utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i))) @@ -545,7 +547,7 @@ class SRGANModel(BaseModel): self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor) - if self.l_gan_w > 0 and step > self.G_warmup: + if self.l_gan_w > 0 and step >= self.G_warmup: self.add_log_entry('l_d_real', l_d_real_log.item()) self.add_log_entry('l_d_fake', l_d_fake_log.item()) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))