Test modifications

Allows bifurcating large images put into the test pipeline

This code is fixed and not dynamic. Needs some fixes.
This commit is contained in:
James Betker 2020-05-19 09:37:58 -06:00
parent 6400607fc5
commit 67139602f5
3 changed files with 17 additions and 41 deletions

View File

@ -3,6 +3,9 @@ import lmdb
import torch import torch
import torch.utils.data as data import torch.utils.data as data
import data.util as util import data.util as util
import torchvision.transforms.functional as F
from PIL import Image
import os.path as osp
class LQDataset(data.Dataset): class LQDataset(data.Dataset):
@ -30,24 +33,20 @@ class LQDataset(data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
if self.data_type == 'lmdb' and self.LQ_env is None: if self.data_type == 'lmdb' and self.LQ_env is None:
self._init_lmdb() self._init_lmdb()
LQ_path = None actual_index = int(index / 2)
is_left = (index % 2) == 0
# get LQ image # get LQ image
LQ_path = self.paths_LQ[index] LQ_path = self.paths_LQ[actual_index]
resolution = [int(s) for s in self.sizes_LQ[index].split('_') img_LQ = Image.open(LQ_path)
] if self.data_type == 'lmdb' else None left = 0 if is_left else 2000
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) img_LQ = F.crop(img_LQ, 74, left + 74, 1900, 1900)
H, W, C = img_LQ.shape img_LQ = F.to_tensor(img_LQ)
if self.opt['color']: # change color space if necessary img_name = osp.splitext(osp.basename(LQ_path))[0]
img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0] LQ_path = LQ_path.replace(img_name, img_name + "_%i" % (index % 2))
# BGR to RGB, HWC to CHW, numpy to tensor
if img_LQ.shape[2] == 3:
img_LQ = img_LQ[:, :, [2, 1, 0]]
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
return {'LQ': img_LQ, 'LQ_path': LQ_path} return {'LQ': img_LQ, 'LQ_path': LQ_path}
def __len__(self): def __len__(self):
return len(self.paths_LQ) return len(self.paths_LQ) * 2

View File

@ -22,7 +22,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
pin_memory=False) pin_memory=False)
else: else:
batch_size = dataset_opt['batch_size'] or 1 batch_size = dataset_opt['batch_size'] or 1
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=int(batch_size/2), return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=max(int(batch_size/2), 1),
pin_memory=False) pin_memory=False)

View File

@ -10,12 +10,14 @@ from data.util import bgr2ycbcr
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from models import create_model from models import create_model
from tqdm import tqdm from tqdm import tqdm
import torch
if __name__ == "__main__": if __name__ == "__main__":
#### options #### options
torch.backends.cudnn.benchmark = True
want_just_images = True want_just_images = True
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_vix_corrupt.yml') parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/use_vrp_upsample.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
@ -76,31 +78,6 @@ if __name__ == "__main__":
if want_just_images: if want_just_images:
continue continue
# calculate PSNR and SSIM
if need_GT:
gt_img = util.tensor2img(visuals['GT'])
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
psnr = util.calculate_psnr(sr_img, gt_img)
ssim = util.calculate_ssim(sr_img, gt_img)
test_results['psnr'].append(psnr)
test_results['ssim'].append(ssim)
if gt_img.shape[2] == 3: # RGB image
sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
test_results['psnr_y'].append(psnr_y)
test_results['ssim_y'].append(ssim_y)
logger.info(
'{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
format(img_name, psnr, ssim, psnr_y, ssim_y))
else:
logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
else:
logger.info(img_name)
if not want_just_images and need_GT: # metrics if not want_just_images and need_GT: # metrics
# Average PSNR/SSIM results # Average PSNR/SSIM results
ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])