Add G_warmup

Let the generator train for the first `G_warmup` steps, until it is at least competitive with the discriminator, before discriminator training starts.

This is backwards from how most GAN architectures are trained, but this one is a bit different from most.
This commit is contained in:
James Betker 2020-07-05 21:58:35 -06:00
parent a47a5dca43
commit 909007ee6a

View File

@ -89,6 +89,7 @@ class SRGANModel(BaseModel):
# D_update_ratio and D_init_iters # D_update_ratio and D_init_iters
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else 0
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0 self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0 self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0 self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
@ -300,7 +301,7 @@ class SRGANModel(BaseModel):
_t = time() _t = time()
# D # D
if self.l_gan_w > 0: if self.l_gan_w > 0 and step > self.G_warmup:
for p in self.netD.parameters(): for p in self.netD.parameters():
p.requires_grad = True p.requires_grad = True
@ -413,7 +414,7 @@ class SRGANModel(BaseModel):
if self.l_gan_w > 0: if self.l_gan_w > 0:
self.add_log_entry('l_g_gan', l_g_gan.item()) self.add_log_entry('l_g_gan', l_g_gan.item())
self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor) self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor)
if self.l_gan_w > 0: if self.l_gan_w > 0 and step > self.G_warmup:
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))