diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index fa330a09..81537ff3 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -256,11 +256,13 @@ class SRGANModel(BaseModel): if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix) + l_g_pix_log = l_g_pix / self.l_pix_w l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(pix).detach() fake_fea = self.netF(gen_img) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) + l_g_fea_log = l_g_fea / self.l_fea_w l_g_total += l_g_fea if _profile: @@ -282,9 +284,11 @@ class SRGANModel(BaseModel): l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + l_g_gan_log = l_g_gan / self.l_gan_w l_g_total += l_g_gan # Scale the loss down by the batch factor. + l_g_total_log = l_g_total l_g_total = l_g_total / self.mega_batch_factor with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled: @@ -327,11 +331,13 @@ class SRGANModel(BaseModel): # real pred_d_real = self.netD(var_ref) l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor + l_d_real_log = l_d_real * self.mega_batch_factor with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: l_d_real_scaled.backward() # fake pred_d_fake = self.netD(fake_H) l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor + l_d_fake_log = l_d_fake * self.mega_batch_factor with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() elif self.opt['train']['gan_type'] == 'ragan': @@ -349,6 +355,7 @@ class SRGANModel(BaseModel): _t = time() l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor + l_d_real_log = l_d_real * self.mega_batch_factor * 2 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: l_d_real_scaled.backward() @@ -358,6 +365,7 @@ class SRGANModel(BaseModel): pred_d_fake = self.netD(fake_H) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor + l_d_fake_log = l_d_fake * self.mega_batch_factor * 2 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() @@ -407,19 +415,18 @@ class SRGANModel(BaseModel): # Log metrics if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: - self.add_log_entry('l_g_pix', l_g_pix.item()) + self.add_log_entry('l_g_pix', l_g_pix_log.item()) if self.cri_fea: self.add_log_entry('feature_weight', self.l_fea_w) - self.add_log_entry('l_g_fea', l_g_fea.item()) + self.add_log_entry('l_g_fea', l_g_fea_log.item()) if self.l_gan_w > 0: - self.add_log_entry('l_g_gan', l_g_gan.item()) - self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor) + self.add_log_entry('l_g_gan', l_g_gan_log.item()) + self.add_log_entry('l_g_total', l_g_total_log.item() * self.mega_batch_factor) if self.l_gan_w > 0 and step > self.G_warmup: - self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) - self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) + self.add_log_entry('l_d_real', l_d_real_log.item() * self.mega_batch_factor) + self.add_log_entry('l_d_fake', l_d_fake_log.item() * self.mega_batch_factor) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) - self.add_log_entry('noise_theta', noise_theta) if step % self.corruptor_swapout_steps == 0 and step > 0: self.load_random_corruptor()