Misc fixes

This commit is contained in:
James Betker 2020-07-26 22:44:24 -06:00
parent c54784ae9e
commit d09ed4e5f7

View File

@ -102,7 +102,7 @@ class SRGANModel(BaseModel):
# D_update_ratio and D_init_iters
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else 0
self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else -1
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
@ -200,6 +200,8 @@ class SRGANModel(BaseModel):
# GAN LQ image params
self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
self.img_debug_steps = train_opt['img_debug_steps'] if train_opt['img_debug_steps'] else 50
self.print_network() # print network
self.load() # load G and D if needed
self.load_random_corruptor()
@ -356,7 +358,7 @@ class SRGANModel(BaseModel):
_t = time()
# D
if self.l_gan_w > 0 and step > self.G_warmup:
if self.l_gan_w > 0 and step >= self.G_warmup:
for p in self.netD.parameters():
if p.dtype != torch.int64 and p.dtype != torch.bool:
p.requires_grad = True
@ -506,7 +508,7 @@ class SRGANModel(BaseModel):
_t = time()
# Log sample images from first microbatch.
if step % 50 == 0:
if step % self.img_debug_steps == 0:
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
@ -524,7 +526,7 @@ class SRGANModel(BaseModel):
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
if self.l_gan_w > 0 and step > self.G_warmup and 'pixgan' in self.opt['train']['gan_type']:
if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']:
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
@ -545,7 +547,7 @@ class SRGANModel(BaseModel):
self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
if self.l_gan_w > 0 and step > self.G_warmup:
if self.l_gan_w > 0 and step >= self.G_warmup:
self.add_log_entry('l_d_real', l_d_real_log.item())
self.add_log_entry('l_d_fake', l_d_fake_log.item())
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))