Misc fixes
This commit is contained in:
parent
c54784ae9e
commit
d09ed4e5f7
|
@ -102,7 +102,7 @@ class SRGANModel(BaseModel):
|
|||
# D_update_ratio and D_init_iters
|
||||
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
||||
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
||||
self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else 0
|
||||
self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else -1
|
||||
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
|
||||
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
||||
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
|
||||
|
@ -200,6 +200,8 @@ class SRGANModel(BaseModel):
|
|||
# GAN LQ image params
|
||||
self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
|
||||
|
||||
self.img_debug_steps = train_opt['img_debug_steps'] if train_opt['img_debug_steps'] else 50
|
||||
|
||||
self.print_network() # print network
|
||||
self.load() # load G and D if needed
|
||||
self.load_random_corruptor()
|
||||
|
@ -356,7 +358,7 @@ class SRGANModel(BaseModel):
|
|||
_t = time()
|
||||
|
||||
# D
|
||||
if self.l_gan_w > 0 and step > self.G_warmup:
|
||||
if self.l_gan_w > 0 and step >= self.G_warmup:
|
||||
for p in self.netD.parameters():
|
||||
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
||||
p.requires_grad = True
|
||||
|
@ -506,7 +508,7 @@ class SRGANModel(BaseModel):
|
|||
_t = time()
|
||||
|
||||
# Log sample images from first microbatch.
|
||||
if step % 50 == 0:
|
||||
if step % self.img_debug_steps == 0:
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
|
||||
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
|
||||
|
@ -524,7 +526,7 @@ class SRGANModel(BaseModel):
|
|||
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
|
||||
if self.l_gan_w > 0 and step > self.G_warmup and 'pixgan' in self.opt['train']['gan_type']:
|
||||
if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']:
|
||||
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
|
||||
|
@ -545,7 +547,7 @@ class SRGANModel(BaseModel):
|
|||
self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
|
||||
if self.l_gan_w > 0 and step > self.G_warmup:
|
||||
if self.l_gan_w > 0 and step >= self.G_warmup:
|
||||
self.add_log_entry('l_d_real', l_d_real_log.item())
|
||||
self.add_log_entry('l_d_fake', l_d_fake_log.item())
|
||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||
|
|
Loading…
Reference in New Issue
Block a user