Update log metrics

Logged loss values are now normalized (divided back by their loss weights
and the mega-batch factor) so they are comparable regardless of job
configuration.
This commit is contained in:
James Betker 2020-07-07 15:33:22 -06:00
parent 8a4eb8241d
commit 4305be97b4

View File

@ -256,11 +256,13 @@ class SRGANModel(BaseModel):
if step % self.D_update_ratio == 0 and step > self.D_init_iters: if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix) l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
l_g_pix_log = l_g_pix / self.l_pix_w
l_g_total += l_g_pix l_g_total += l_g_pix
if self.cri_fea: # feature loss if self.cri_fea: # feature loss
real_fea = self.netF(pix).detach() real_fea = self.netF(pix).detach()
fake_fea = self.netF(gen_img) fake_fea = self.netF(gen_img)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_fea_log = l_g_fea / self.l_fea_w
l_g_total += l_g_fea l_g_total += l_g_fea
if _profile: if _profile:
@ -282,9 +284,11 @@ class SRGANModel(BaseModel):
l_g_gan = self.l_gan_w * ( l_g_gan = self.l_gan_w * (
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
l_g_gan_log = l_g_gan / self.l_gan_w
l_g_total += l_g_gan l_g_total += l_g_gan
# Scale the loss down by the batch factor. # Scale the loss down by the batch factor.
l_g_total_log = l_g_total
l_g_total = l_g_total / self.mega_batch_factor l_g_total = l_g_total / self.mega_batch_factor
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled: with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
@ -327,11 +331,13 @@ class SRGANModel(BaseModel):
# real # real
pred_d_real = self.netD(var_ref) pred_d_real = self.netD(var_ref)
l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
l_d_real_log = l_d_real * self.mega_batch_factor
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward() l_d_real_scaled.backward()
# fake # fake
pred_d_fake = self.netD(fake_H) pred_d_fake = self.netD(fake_H)
l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
l_d_fake_log = l_d_fake * self.mega_batch_factor
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
elif self.opt['train']['gan_type'] == 'ragan': elif self.opt['train']['gan_type'] == 'ragan':
@ -349,6 +355,7 @@ class SRGANModel(BaseModel):
_t = time() _t = time()
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
l_d_real_log = l_d_real * self.mega_batch_factor * 2
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward() l_d_real_scaled.backward()
@ -358,6 +365,7 @@ class SRGANModel(BaseModel):
pred_d_fake = self.netD(fake_H) pred_d_fake = self.netD(fake_H)
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
l_d_fake_log = l_d_fake * self.mega_batch_factor * 2
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
@ -407,19 +415,18 @@ class SRGANModel(BaseModel):
# Log metrics # Log metrics
if step % self.D_update_ratio == 0 and step > self.D_init_iters: if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: if self.cri_pix:
self.add_log_entry('l_g_pix', l_g_pix.item()) self.add_log_entry('l_g_pix', l_g_pix_log.item())
if self.cri_fea: if self.cri_fea:
self.add_log_entry('feature_weight', self.l_fea_w) self.add_log_entry('feature_weight', self.l_fea_w)
self.add_log_entry('l_g_fea', l_g_fea.item()) self.add_log_entry('l_g_fea', l_g_fea_log.item())
if self.l_gan_w > 0: if self.l_gan_w > 0:
self.add_log_entry('l_g_gan', l_g_gan.item()) self.add_log_entry('l_g_gan', l_g_gan_log.item())
self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor) self.add_log_entry('l_g_total', l_g_total_log.item() * self.mega_batch_factor)
if self.l_gan_w > 0 and step > self.G_warmup: if self.l_gan_w > 0 and step > self.G_warmup:
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_real', l_d_real_log.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake', l_d_fake_log.item() * self.mega_batch_factor)
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
self.add_log_entry('noise_theta', noise_theta)
if step % self.corruptor_swapout_steps == 0 and step > 0: if step % self.corruptor_swapout_steps == 0 and step > 0:
self.load_random_corruptor() self.load_random_corruptor()