From 67139602f5b875bf5320bcbfe2c0f946960691fd Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 19 May 2020 09:37:58 -0600
Subject: [PATCH] Test modifications

Allows bifurcating large images put into the test pipeline. This code is
currently hard-coded rather than dynamic and needs some fixes.
---
 codes/data/LQ_dataset.py | 27 +++++++++++++--------------
 codes/data/__init__.py   |  2 +-
 codes/test.py            | 29 +++--------------------------
 3 files changed, 17 insertions(+), 41 deletions(-)

diff --git a/codes/data/LQ_dataset.py b/codes/data/LQ_dataset.py
index 535c591a..29555dee 100644
--- a/codes/data/LQ_dataset.py
+++ b/codes/data/LQ_dataset.py
@@ -3,6 +3,9 @@ import lmdb
 import torch
 import torch.utils.data as data
 import data.util as util
+import torchvision.transforms.functional as F
+from PIL import Image
+import os.path as osp
 
 
 class LQDataset(data.Dataset):
@@ -30,24 +33,20 @@ class LQDataset(data.Dataset):
     def __getitem__(self, index):
         if self.data_type == 'lmdb' and self.LQ_env is None:
             self._init_lmdb()
-        LQ_path = None
+        actual_index = int(index / 2)
+        is_left = (index % 2) == 0
 
         # get LQ image
-        LQ_path = self.paths_LQ[index]
-        resolution = [int(s) for s in self.sizes_LQ[index].split('_')
-                      ] if self.data_type == 'lmdb' else None
-        img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
-        H, W, C = img_LQ.shape
+        LQ_path = self.paths_LQ[actual_index]
+        img_LQ = Image.open(LQ_path)
+        left = 0 if is_left else 2000
+        img_LQ = F.crop(img_LQ, 74, left + 74, 1900, 1900)
+        img_LQ = F.to_tensor(img_LQ)
 
-        if self.opt['color']:  # change color space if necessary
-            img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0]
-
-        # BGR to RGB, HWC to CHW, numpy to tensor
-        if img_LQ.shape[2] == 3:
-            img_LQ = img_LQ[:, :, [2, 1, 0]]
-        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
+        img_name = osp.splitext(osp.basename(LQ_path))[0]
+        LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % 2))
         return {'LQ': img_LQ, 'LQ_path': LQ_path}
 
     def __len__(self):
-        return len(self.paths_LQ)
+        return len(self.paths_LQ) * 2
diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 97157006..e4aa847c 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -22,7 +22,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
                                            pin_memory=False)
     else:
         batch_size = dataset_opt['batch_size'] or 1
-        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=int(batch_size/2),
+        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=max(int(batch_size/2), 1),
                                            pin_memory=False)
diff --git a/codes/test.py b/codes/test.py
index 8cea6476..edc70045 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -10,12 +10,14 @@ from data.util import bgr2ycbcr
 from data import create_dataset, create_dataloader
 from models import create_model
 from tqdm import tqdm
+import torch
 
 if __name__ == "__main__":
     #### options
+    torch.backends.cudnn.benchmark = True
     want_just_images = True
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_vix_corrupt.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/use_vrp_upsample.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
 
@@ -76,31 +78,6 @@ if __name__ == "__main__":
             if want_just_images:
                 continue
 
-            # calculate PSNR and SSIM
-            if need_GT:
-                gt_img = util.tensor2img(visuals['GT'])
-                sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                psnr = util.calculate_psnr(sr_img, gt_img)
-                ssim = util.calculate_ssim(sr_img, gt_img)
-                test_results['psnr'].append(psnr)
-                test_results['ssim'].append(ssim)
-
-                if gt_img.shape[2] == 3:  # RGB image
-                    sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
-                    gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
-
-                    psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
-                    ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
-                    test_results['psnr_y'].append(psnr_y)
-                    test_results['ssim_y'].append(ssim_y)
-                    logger.info(
-                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
-                        format(img_name, psnr, ssim, psnr_y, ssim_y))
-                else:
-                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
-            else:
-                logger.info(img_name)
-
         if not want_just_images and need_GT:  # metrics
             # Average PSNR/SSIM results
             ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
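
The bifurcation added to LQ_dataset.__getitem__ above is hard-coded for one source layout (left offsets 0/2000, a 74 px margin, fixed 1900x1900 windows), which is what the commit message flags as still needing fixes. Below is a minimal sketch of how the same left/right split could instead be derived from the actual image size. It assumes the PIL and torchvision imports already added by the patch; the load_bifurcated helper, its margin default, and the half-width arithmetic are illustrative assumptions, not part of this change.

    # Sketch only: derive the two crop windows from the image size instead of
    # the fixed 2000 / 74 / 1900 constants used in the patch.
    import os.path as osp
    import torchvision.transforms.functional as F
    from PIL import Image

    def load_bifurcated(paths_LQ, index, margin=74):
        actual_index = index // 2          # two dataset entries per source image
        is_left = (index % 2) == 0
        lq_path = paths_LQ[actual_index]
        img = Image.open(lq_path)
        width, height = img.size           # PIL reports (W, H)
        crop_w = width // 2 - 2 * margin   # one half of the image, minus margins
        crop_h = height - 2 * margin
        left = margin if is_left else width // 2 + margin
        img = F.crop(img, margin, left, crop_h, crop_w)  # crop(img, top, left, height, width)
        img_t = F.to_tensor(img)
        name = osp.splitext(osp.basename(lq_path))[0]
        # Suffix the reported path the same way the patch does, so the two halves
        # are saved to distinct output files.
        return {'LQ': img_t, 'LQ_path': lq_path.replace(name, name + "_%i" % (index % 2))}

With a dynamic crop like this, __len__ would still return len(self.paths_LQ) * 2, exactly as in the patch.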