From 726d1913acc727a2a2d3142ddb692f6c0b82a2d0 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 2 Jun 2020 08:41:22 -0600
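Subject: [PATCH] Allow validating in batches, remove val size limit

Image validation in train.py now walks every image of each validation
batch (batch size read from datasets.val.batch_size, default 1) and the
hard-coded "stop once idx reaches 20" cap is removed.
SRModel/SRGANModel.get_current_visuals() no longer take element [0] of
their outputs, so callers index the returned tensors per image.

Also removes the video-restoration validation branch from train.py and
changes the default -opt file to ../options/train_cifar_rrdb.yml.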
---
 codes/models/SRGAN_model.py |   6 +-
 codes/models/SR_model.py    |   6 +-
 codes/train.py              | 126 +++++++-----------------------------
 3 files changed, 28 insertions(+), 110 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index fa1dea4d..58f34aa9 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -448,13 +448,13 @@ class SRGANModel(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
-        out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
+        out_dict['LQ'] = self.var_L[0].detach().float().cpu()  # trailing [0] dropped: full tensor, not only its first element
         gen_batch = self.fake_GenOut[0]
         if isinstance(gen_batch, tuple):
             gen_batch = gen_batch[0]
-        out_dict['rlt'] = gen_batch.detach()[0].float().cpu()
+        out_dict['rlt'] = gen_batch.detach().float().cpu()  # train.py validation indexes this as visuals['rlt'][b]
         if need_GT:
-            out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
+            out_dict['GT'] = self.var_H[0].detach().float().cpu()  # train.py validation indexes this as visuals['GT'][b]
         return out_dict
 
     def print_network(self):
diff --git a/codes/models/SR_model.py b/codes/models/SR_model.py
index bf46ea3f..02fc7a7c 100644
--- a/codes/models/SR_model.py
+++ b/codes/models/SR_model.py
@@ -144,10 +144,10 @@ class SRModel(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
-        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
-        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
+        out_dict['LQ'] = self.var_L.detach().float().cpu()  # [0] dropped: full tensor, not only its first element
+        out_dict['rlt'] = self.fake_H.detach().float().cpu()  # train.py validation indexes this as visuals['rlt'][b]
         if need_GT:
-            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
+            out_dict['GT'] = self.real_H.detach().float().cpu()  # train.py validation indexes this as visuals['GT'][b]
         return out_dict
 
     def print_network(self):
diff --git a/codes/train.py b/codes/train.py
index c2173eca..d93358e4 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imset_pre_rrdb.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cifar_rrdb.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -204,38 +204,38 @@ def main():
             if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                 if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0:  # image restoration validation
                     model.force_restore_swapout()
+                    val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
                     # does not support multi-GPU validation
-                    pbar = util.ProgressBar(len(val_loader))
+                    pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
                     avg_psnr = 0.
                     idx = 0
                     colab_imgs_to_copy = []
                     for val_data in val_loader:
-                        idx += 1
-                        if idx >= 20:
-                            break
-                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
-                        img_dir = os.path.join(opt['path']['val_images'], img_name)
-                        util.mkdir(img_dir)
-
+                        # Run the model once per batch; the per-image loop below only slices these results.
                         model.feed_data(val_data)
                         model.test()
-
                         visuals = model.get_current_visuals()
 
-                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
-                        gt_img = util.tensor2img(visuals['GT'])  # uint8
+                        for b, lq_path in enumerate(val_data['LQ_path']):
+                            idx += 1  # one tick per image (not per batch), matching the pre-batching meaning of idx
+                            img_name = os.path.splitext(os.path.basename(lq_path))[0]
+                            img_dir = os.path.join(opt['path']['val_images'], img_name)
+                            util.mkdir(img_dir)
+
+                            sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
+                            gt_img = util.tensor2img(visuals['GT'][b])  # uint8
 
-                        # Save SR images for reference
-                        img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
-                        save_img_path = os.path.join(img_dir, img_base_name)
-                        util.save_img(sr_img, save_img_path)
-                        if colab_mode:
-                            colab_imgs_to_copy.append(save_img_path)
+                            # Save SR images for reference
+                            img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
+                            save_img_path = os.path.join(img_dir, img_base_name)
+                            util.save_img(sr_img, save_img_path)
+                            if colab_mode:
+                                colab_imgs_to_copy.append(save_img_path)
 
-                        # calculate PSNR
-                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
-                        pbar.update('Test {}'.format(img_name))
+                            # calculate PSNR
+                            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                            avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                            pbar.update('Test {}'.format(img_name))
 
                     if colab_mode:
                         util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
@@ -249,88 +249,6 @@ def main():
                     # tensorboard logger
                     if opt['use_tb_logger'] and 'debug' not in opt['name']:
                         tb_logger.add_scalar('psnr', avg_psnr, current_step)
-                else:  # video restoration validation
-                    if opt['dist']:
-                        # multi-GPU testing
-                        psnr_rlt = {}  # with border and center frames
-                        if rank == 0:
-                            pbar = util.ProgressBar(len(val_set))
-                        for idx in range(rank, len(val_set), world_size):
-                            val_data = val_set[idx]
-                            val_data['LQs'].unsqueeze_(0)
-                            val_data['GT'].unsqueeze_(0)
-                            folder = val_data['folder']
-                            idx_d, max_idx = val_data['idx'].split('/')
-                            idx_d, max_idx = int(idx_d), int(max_idx)
-                            if psnr_rlt.get(folder, None) is None:
-                                psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
-                                                               device='cuda')
-                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
-                            model.feed_data(val_data)
-                            model.test()
-                            visuals = model.get_current_visuals()
-                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
-                            gt_img = util.tensor2img(visuals['GT'])  # uint8
-                            # calculate PSNR
-                            psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
-
-                            if rank == 0:
-                                for _ in range(world_size):
-                                    pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
-                        # # collect data
-                        for _, v in psnr_rlt.items():
-                            dist.reduce(v, 0)
-                        dist.barrier()
-
-                        if rank == 0:
-                            psnr_rlt_avg = {}
-                            psnr_total_avg = 0.
-                            for k, v in psnr_rlt.items():
-                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
-                                psnr_total_avg += psnr_rlt_avg[k]
-                            psnr_total_avg /= len(psnr_rlt)
-                            log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
-                            for k, v in psnr_rlt_avg.items():
-                                log_s += ' {}: {:.4e}'.format(k, v)
-                            logger.info(log_s)
-                            if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                                tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
-                                for k, v in psnr_rlt_avg.items():
-                                    tb_logger.add_scalar(k, v, current_step)
-                    else:
-                        pbar = util.ProgressBar(len(val_loader))
-                        psnr_rlt = {}  # with border and center frames
-                        psnr_rlt_avg = {}
-                        psnr_total_avg = 0.
-                        for val_data in val_loader:
-                            folder = val_data['folder'][0]
-                            idx_d = val_data['idx'].item()
-                            # border = val_data['border'].item()
-                            if psnr_rlt.get(folder, None) is None:
-                                psnr_rlt[folder] = []
-
-                            model.feed_data(val_data)
-                            model.test()
-                            visuals = model.get_current_visuals()
-                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
-                            gt_img = util.tensor2img(visuals['GT'])  # uint8
-
-                            # calculate PSNR
-                            psnr = util.calculate_psnr(rlt_img, gt_img)
-                            psnr_rlt[folder].append(psnr)
-                            pbar.update('Test {} - {}'.format(folder, idx_d))
-                        for k, v in psnr_rlt.items():
-                            psnr_rlt_avg[k] = sum(v) / len(v)
-                            psnr_total_avg += psnr_rlt_avg[k]
-                        psnr_total_avg /= len(psnr_rlt)
-                        log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
-                        for k, v in psnr_rlt_avg.items():
-                            log_s += ' {}: {:.4e}'.format(k, v)
-                        logger.info(log_s)
-                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                            tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
-                            for k, v in psnr_rlt_avg.items():
-                                tb_logger.add_scalar(k, v, current_step)
 
             #### save models and training states
             if current_step % opt['logger']['save_checkpoint_freq'] == 0: