From 726d1913acc727a2a2d3142ddb692f6c0b82a2d0 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 2 Jun 2020 08:41:22 -0600
Subject: [PATCH] Allow validating in batches, remove val size limit

---
 codes/models/SRGAN_model.py |   6 +-
 codes/models/SR_model.py    |   6 +-
 codes/train.py              | 126 +++++++----------------------------
 3 files changed, 28 insertions(+), 110 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index fa1dea4d..58f34aa9 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -448,13 +448,13 @@ class SRGANModel(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
-        out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
+        out_dict['LQ'] = self.var_L[0].detach().float().cpu()
         gen_batch = self.fake_GenOut[0]
         if isinstance(gen_batch, tuple):
             gen_batch = gen_batch[0]
-        out_dict['rlt'] = gen_batch.detach()[0].float().cpu()
+        out_dict['rlt'] = gen_batch.detach().float().cpu()
         if need_GT:
-            out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
+            out_dict['GT'] = self.var_H[0].detach().float().cpu()
         return out_dict
 
     def print_network(self):
diff --git a/codes/models/SR_model.py b/codes/models/SR_model.py
index bf46ea3f..02fc7a7c 100644
--- a/codes/models/SR_model.py
+++ b/codes/models/SR_model.py
@@ -144,10 +144,10 @@ class SRModel(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
-        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
-        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
+        out_dict['LQ'] = self.var_L.detach().float().cpu()
+        out_dict['rlt'] = self.fake_H.detach().float().cpu()
         if need_GT:
-            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
+            out_dict['GT'] = self.real_H.detach().float().cpu()
         return out_dict
 
     def print_network(self):
diff --git a/codes/train.py b/codes/train.py
index c2173eca..d93358e4 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imset_pre_rrdb.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cifar_rrdb.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -204,38 +204,38 @@ def main():
             if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                 if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0:  # image restoration validation
                     model.force_restore_swapout()
+                    val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
                     # does not support multi-GPU validation
-                    pbar = util.ProgressBar(len(val_loader))
+                    pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
                     avg_psnr = 0.
                     idx = 0
                     colab_imgs_to_copy = []
                     for val_data in val_loader:
                         idx += 1
-                        if idx >= 20:
-                            break
-                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
-                        img_dir = os.path.join(opt['path']['val_images'], img_name)
-                        util.mkdir(img_dir)
+                        for b in range(len(val_data['LQ_path'])):
+                            img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
+                            img_dir = os.path.join(opt['path']['val_images'], img_name)
+                            util.mkdir(img_dir)
 
-                        model.feed_data(val_data)
-                        model.test()
+                            model.feed_data(val_data)
+                            model.test()
 
-                        visuals = model.get_current_visuals()
+                            visuals = model.get_current_visuals()
 
-                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
-                        gt_img = util.tensor2img(visuals['GT'])  # uint8
+                            sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
+                            gt_img = util.tensor2img(visuals['GT'][b])  # uint8
 
-                        # Save SR images for reference
-                        img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
-                        save_img_path = os.path.join(img_dir, img_base_name)
-                        util.save_img(sr_img, save_img_path)
-                        if colab_mode:
-                            colab_imgs_to_copy.append(save_img_path)
+                            # Save SR images for reference
+                            img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
+                            save_img_path = os.path.join(img_dir, img_base_name)
+                            util.save_img(sr_img, save_img_path)
+                            if colab_mode:
+                                colab_imgs_to_copy.append(save_img_path)
 
-                        # calculate PSNR
-                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
-                        pbar.update('Test {}'.format(img_name))
+                            # calculate PSNR
+                            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                            avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                            pbar.update('Test {}'.format(img_name))
 
                     if colab_mode:
                         util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
@@ -249,88 +249,6 @@ def main():
                     # tensorboard logger
                     if opt['use_tb_logger'] and 'debug' not in opt['name']:
                         tb_logger.add_scalar('psnr', avg_psnr, current_step)
-                else:  # video restoration validation
-                    if opt['dist']:
-                        # multi-GPU testing
-                        psnr_rlt = {}  # with border and center frames
-                        if rank == 0:
-                            pbar = util.ProgressBar(len(val_set))
-                        for idx in range(rank, len(val_set), world_size):
-                            val_data = val_set[idx]
-                            val_data['LQs'].unsqueeze_(0)
-                            val_data['GT'].unsqueeze_(0)
-                            folder = val_data['folder']
-                            idx_d, max_idx = val_data['idx'].split('/')
-                            idx_d, max_idx = int(idx_d), int(max_idx)
-                            if psnr_rlt.get(folder, None) is None:
-                                psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
-                                                               device='cuda')
-                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
-                            model.feed_data(val_data)
-                            model.test()
-                            visuals = model.get_current_visuals()
-                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
-                            gt_img = util.tensor2img(visuals['GT'])  # uint8
-                            # calculate PSNR
-                            psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
-
-                            if rank == 0:
-                                for _ in range(world_size):
-                                    pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
-                        # # collect data
-                        for _, v in psnr_rlt.items():
-                            dist.reduce(v, 0)
-                        dist.barrier()
-
-                        if rank == 0:
-                            psnr_rlt_avg = {}
-                            psnr_total_avg = 0.
-                            for k, v in psnr_rlt.items():
-                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
-                                psnr_total_avg += psnr_rlt_avg[k]
-                            psnr_total_avg /= len(psnr_rlt)
-                            log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
-                            for k, v in psnr_rlt_avg.items():
-                                log_s += ' {}: {:.4e}'.format(k, v)
-                            logger.info(log_s)
-                            if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                                tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
-                                for k, v in psnr_rlt_avg.items():
-                                    tb_logger.add_scalar(k, v, current_step)
-                    else:
-                        pbar = util.ProgressBar(len(val_loader))
-                        psnr_rlt = {}  # with border and center frames
-                        psnr_rlt_avg = {}
-                        psnr_total_avg = 0.
-                        for val_data in val_loader:
-                            folder = val_data['folder'][0]
-                            idx_d = val_data['idx'].item()
-                            # border = val_data['border'].item()
-                            if psnr_rlt.get(folder, None) is None:
-                                psnr_rlt[folder] = []
-
-                            model.feed_data(val_data)
-                            model.test()
-                            visuals = model.get_current_visuals()
-                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
-                            gt_img = util.tensor2img(visuals['GT'])  # uint8
-
-                            # calculate PSNR
-                            psnr = util.calculate_psnr(rlt_img, gt_img)
-                            psnr_rlt[folder].append(psnr)
-                            pbar.update('Test {} - {}'.format(folder, idx_d))
-                        for k, v in psnr_rlt.items():
-                            psnr_rlt_avg[k] = sum(v) / len(v)
-                            psnr_total_avg += psnr_rlt_avg[k]
-                        psnr_total_avg /= len(psnr_rlt)
-                        log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
-                        for k, v in psnr_rlt_avg.items():
-                            log_s += ' {}: {:.4e}'.format(k, v)
-                        logger.info(log_s)
-                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                            tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
-                            for k, v in psnr_rlt_avg.items():
-                                tb_logger.add_scalar(k, v, current_step)
 
         #### save models and training states
         if current_step % opt['logger']['save_checkpoint_freq'] == 0:
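
Note (not part of the patch): a minimal sketch of the caller-side effect of the
get_current_visuals() changes above. Before this commit the method indexed the
tensors with [0] and returned a single image per key; afterwards it returns the
whole batch, which is why the validation loop in train.py now picks out each
sample with [b]. The stand-in function name, tensor shapes, and random tensors
below are illustrative assumptions, not code from the repository.

# Illustrative sketch only -- assumes batched tensors shaped (N, C, H, W).
import torch
from collections import OrderedDict

def get_current_visuals_batched(var_L, fake_H, real_H):
    # Mirrors SRModel.get_current_visuals() after the patch: no per-sample
    # [0] indexing, so every key holds the full batch.
    out_dict = OrderedDict()
    out_dict['LQ'] = var_L.detach().float().cpu()    # (N, C, h, w) LR inputs
    out_dict['rlt'] = fake_H.detach().float().cpu()  # (N, C, H, W) SR outputs
    out_dict['GT'] = real_H.detach().float().cpu()   # (N, C, H, W) ground truth
    return out_dict

# Hypothetical usage mirroring the new validation loop: index sample b.
lq = torch.rand(4, 3, 16, 16)   # pretend low-res batch
sr = torch.rand(4, 3, 64, 64)   # pretend super-resolved batch
gt = torch.rand(4, 3, 64, 64)   # pretend ground-truth batch
visuals = get_current_visuals_batched(lq, sr, gt)
for b in range(visuals['rlt'].size(0)):
    sr_img_b = visuals['rlt'][b]  # single (C, H, W) image, what the old API returned directly

With the hard-coded 20-image cap removed, validation now walks the entire
val_loader; the new val_batch_sz lookup (defaulting to 1 when batch_size is
absent from the val dataset options) is used in this hunk only to size the
progress bar.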