Update log metrics
They should now be universal regardless of job configuration
parent 8a4eb8241d
commit 4305be97b4
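The change follows one rule throughout: optimize the weighted, accumulation-scaled losses, but log raw criterion values with the weights and scale factors divided back out, so the curves are comparable across job configurations. A minimal sketch of that rule, with illustrative names rather than this repo's API:

    raw = criterion(fake, real)          # the value we want to see in the logs
    weighted = loss_weight * raw         # the value the optimizer actually uses
    log_value = weighted / loss_weight   # divide the weight back out before logging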
@@ -256,11 +256,13 @@ class SRGANModel(BaseModel):
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             if self.cri_pix:  # pixel loss
                 l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
+                l_g_pix_log = l_g_pix / self.l_pix_w
                 l_g_total += l_g_pix
             if self.cri_fea:  # feature loss
                 real_fea = self.netF(pix).detach()
                 fake_fea = self.netF(gen_img)
                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                l_g_fea_log = l_g_fea / self.l_fea_w
                 l_g_total += l_g_fea
 
             if _profile:
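Dividing the weighted loss by its own weight recovers the raw criterion output, so the logged value no longer depends on how l_pix_w or l_fea_w is set in the job configuration. A toy check of that invariance (hypothetical numbers, not from the repo):

    raw = 0.02                            # pretend cri_pix(gen_img, pix) returned this
    for l_pix_w in (1.0, 100.0):
        l_g_pix = l_pix_w * raw           # what gets optimized
        l_g_pix_log = l_g_pix / l_pix_w   # what gets logged
        assert abs(l_g_pix_log - raw) < 1e-9   # identical under either weight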
@@ -282,9 +284,11 @@ class SRGANModel(BaseModel):
                 l_g_gan = self.l_gan_w * (
                     self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+                l_g_gan_log = l_g_gan / self.l_gan_w
                 l_g_total += l_g_gan
 
             # Scale the loss down by the batch factor.
+            l_g_total_log = l_g_total
             l_g_total = l_g_total / self.mega_batch_factor
 
             with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
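l_g_total_log is snapshotted before the division by mega_batch_factor, so the logged total is unaffected by the gradient-accumulation scaling the optimizer needs. A sketch of that pattern, under the assumption that mega_batch_factor micro-batches are accumulated per optimizer step (simplified names, AMP omitted):

    for micro_batch in micro_batches:          # len(micro_batches) == mega_batch_factor
        loss = compute_loss(micro_batch)
        loss_for_log = loss                    # snapshot before scaling
        (loss / mega_batch_factor).backward()  # accumulated grads approximate a full-batch mean
    optimizer.step()
    optimizer.zero_grad()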
@@ -327,11 +331,13 @@ class SRGANModel(BaseModel):
                 # real
                 pred_d_real = self.netD(var_ref)
                 l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
+                l_d_real_log = l_d_real * self.mega_batch_factor
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 # fake
                 pred_d_fake = self.netD(fake_H)
                 l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
+                l_d_fake_log = l_d_fake * self.mega_batch_factor
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
             elif self.opt['train']['gan_type'] == 'ragan':
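The discriminator losses are divided by mega_batch_factor before backward() so accumulated gradients stay correctly scaled; the _log variables multiply that factor back out, leaving the plain criterion output in the logs regardless of the accumulation setting. In miniature (illustrative names, AMP omitted):

    l_d_real = cri_gan(pred_d_real, True) / mega_batch_factor   # scaled for backward
    l_d_real_log = l_d_real * mega_batch_factor                 # unscaled for the log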
@@ -349,6 +355,7 @@ class SRGANModel(BaseModel):
                     _t = time()
 
                 l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
+                l_d_real_log = l_d_real * self.mega_batch_factor * 2
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
 
@@ -358,6 +365,7 @@ class SRGANModel(BaseModel):
 
                 pred_d_fake = self.netD(fake_H)
                 l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
+                l_d_fake_log = l_d_fake * self.mega_batch_factor * 2
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
 
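In the RaGAN branch each half of the loss additionally carries a 0.5 factor, so the log value multiplies by mega_batch_factor * 2 to undo both scalings at once. Schematically (simplified names):

    half = cri_gan(pred_d_real - pred_d_fake.mean(), True) * 0.5 / mega_batch_factor
    half_log = half * mega_batch_factor * 2   # equals the bare cri_gan(...) output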
@@ -407,19 +415,18 @@ class SRGANModel(BaseModel):
         # Log metrics
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             if self.cri_pix:
-                self.add_log_entry('l_g_pix', l_g_pix.item())
+                self.add_log_entry('l_g_pix', l_g_pix_log.item())
             if self.cri_fea:
                 self.add_log_entry('feature_weight', self.l_fea_w)
-                self.add_log_entry('l_g_fea', l_g_fea.item())
+                self.add_log_entry('l_g_fea', l_g_fea_log.item())
             if self.l_gan_w > 0:
-                self.add_log_entry('l_g_gan', l_g_gan.item())
-                self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor)
+                self.add_log_entry('l_g_gan', l_g_gan_log.item())
+                self.add_log_entry('l_g_total', l_g_total_log.item() * self.mega_batch_factor)
         if self.l_gan_w > 0 and step > self.G_warmup:
-            self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
-            self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
+            self.add_log_entry('l_d_real', l_d_real_log.item() * self.mega_batch_factor)
+            self.add_log_entry('l_d_fake', l_d_fake_log.item() * self.mega_batch_factor)
             self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
             self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
             self.add_log_entry('noise_theta', noise_theta)
 
         if step % self.corruptor_swapout_steps == 0 and step > 0:
             self.load_random_corruptor()
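Net effect: every add_log_entry call now reports a value with the configuration-dependent factors (loss weights, the RaGAN 0.5, gradient-accumulation scaling) normalized out, so runs with different job configurations produce directly comparable curves. A self-contained sketch of that convention (the helper below is hypothetical, not part of this repo):

    import torch

    def log_raw(log: dict, name: str, weighted_loss: torch.Tensor,
                weight: float = 1.0, scale: float = 1.0):
        # Divide out the loss weight and multiply back any scaling applied
        # before backward(), so the stored value is the bare criterion output.
        log[name] = (weighted_loss * scale / weight).item()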