Add a noise floor to th discriminator noise factor

This commit is contained in:
James Betker 2020-05-13 09:19:22 -06:00
parent 5d1b4caabf
commit fc3ec8e3a2

View File

@ -80,6 +80,7 @@ class SRGANModel(BaseModel):
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0 self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0 self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
# optimizers # optimizers
# G # G
@ -171,7 +172,7 @@ class SRGANModel(BaseModel):
if step >= self.D_noise_final: if step >= self.D_noise_final:
noise_theta = 0 noise_theta = 0
else: else:
noise_theta = self.D_noise_theta * (self.D_noise_final - step) / self.D_noise_final noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - step) / self.D_noise_final + self.D_noise_theta_floor
self.fake_GenOut = [] self.fake_GenOut = []
var_ref_skips = [] var_ref_skips = []
@ -287,6 +288,7 @@ class SRGANModel(BaseModel):
utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join("temp/ref", "hi_%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i))) utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i))) utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i)))
else: else: