From da4335c25e0d979c0ede4401334861de720f0557 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 3 Jul 2020 15:18:57 -0600 Subject: [PATCH] Add a feature-based validation test --- codes/models/SRGAN_model.py | 8 ++++++++ codes/train.py | 20 +++++++++++++------- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 4eb6582b..7e4fd2e7 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -449,6 +449,14 @@ class SRGANModel(BaseModel): # Just a note: this intentionally includes the swap model in the list of possibilities. return previous_models[random.randint(0, len(previous_models)-1)] + def compute_fea_loss(self, real, fake): + with torch.no_grad(): + real = real.unsqueeze(dim=0) + fake = fake.unsqueeze(dim=0) + real_fea = self.netF(real).detach() + fake_fea = self.netF(fake) + return self.cri_fea(fake_fea, real_fea).item() + # Called before verification/checkpoint to ensure we're using the real models and not a swapout variant. def force_restore_swapout(self): if self.swapout_D_duration > 0: diff --git a/codes/train.py b/codes/train.py index 8c7c0dc8..14de62cd 100644 --- a/codes/train.py +++ b/codes/train.py @@ -162,7 +162,7 @@ def main(): current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: - current_step = 0 + current_step = -1 start_epoch = 0 #### training @@ -221,6 +221,7 @@ def main(): # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader) * val_batch_sz) avg_psnr = 0. + avg_fea_loss = 0. idx = 0 colab_imgs_to_copy = [] for val_data in val_loader: @@ -245,10 +246,13 @@ def main(): if colab_mode: colab_imgs_to_copy.append(save_img_path) - # calculate PSNR - sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) - avg_psnr += util.calculate_psnr(sr_img, gt_img) - pbar.update('Test {}'.format(img_name)) + # calculate PSNR (Naw - don't do that. PSNR sucks) + #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) + #avg_psnr += util.calculate_psnr(sr_img, gt_img) + #pbar.update('Test {}'.format(img_name)) + + # calculate fea loss + avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) if colab_mode: util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], @@ -256,12 +260,14 @@ def main(): os.path.join(opt['remote_path'], 'val_images', img_base_name)) avg_psnr = avg_psnr / idx + avg_fea_loss = avg_fea_loss / idx # log - logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) + logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: - tb_logger.add_scalar('psnr', avg_psnr, current_step) + tb_logger.add_scalar('val_psnr', avg_psnr, current_step) + tb_logger.add_scalar('val_fea', avg_fea_loss, current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: