Add G_warmup
Let the Generator get to a point where it is at least competing with the discriminator before firing off. Backwards from most GAN architectures, but this one is a bit different from most.
This commit is contained in:
parent
a47a5dca43
commit
909007ee6a
|
@ -89,6 +89,7 @@ class SRGANModel(BaseModel):
|
||||||
# D_update_ratio and D_init_iters
|
# D_update_ratio and D_init_iters
|
||||||
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
||||||
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
||||||
|
self.G_warmup = train_opt['G_warmup'] if train_opt['G_warmup'] else 0
|
||||||
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
|
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
|
||||||
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
||||||
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
|
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
|
||||||
|
@ -300,7 +301,7 @@ class SRGANModel(BaseModel):
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
# D
|
# D
|
||||||
if self.l_gan_w > 0:
|
if self.l_gan_w > 0 and step > self.G_warmup:
|
||||||
for p in self.netD.parameters():
|
for p in self.netD.parameters():
|
||||||
p.requires_grad = True
|
p.requires_grad = True
|
||||||
|
|
||||||
|
@ -413,7 +414,7 @@ class SRGANModel(BaseModel):
|
||||||
if self.l_gan_w > 0:
|
if self.l_gan_w > 0:
|
||||||
self.add_log_entry('l_g_gan', l_g_gan.item())
|
self.add_log_entry('l_g_gan', l_g_gan.item())
|
||||||
self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor)
|
self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor)
|
||||||
if self.l_gan_w > 0:
|
if self.l_gan_w > 0 and step > self.G_warmup:
|
||||||
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
|
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
|
||||||
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
|
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
|
||||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user