72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
import numpy as np
|
|
import lmdb
|
|
import torch
|
|
import torch.utils.data as data
|
|
import data.util as util
|
|
import torchvision.transforms.functional as F
|
|
from PIL import Image
|
|
import os.path as osp
|
|
import cv2
|
|
|
|
|
|
class LQDataset(data.Dataset):
    '''Read LQ images only in the test phase.

    Each image may optionally be cut into ``vertical_splits`` equal-width,
    full-height strips placed side by side; every strip is then served as a
    separate sample and its returned path gets a ``_<strip index>`` suffix on
    the file stem.

    Options read from ``opt``:
        data_type: forwarded to ``util.get_image_paths``; ``'lmdb'`` also
            opens an lmdb environment on ``dataroot_LQ`` on first access.
        dataroot_LQ: root passed to ``util.get_image_paths``.
        start_at (optional, default 0): skip the first ``start_at`` paths.
        vertical_splits (optional, default 0): strips per image; ``<= 0``
            disables splitting.
        force_multiple (optional, default 1): resize each sample so that its
            width and height are rounded down to a multiple of this value.
    '''

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.data_type = opt['data_type']
        # Optional keys: a missing (or None) value falls back to the default.
        self.start_at = opt.get('start_at') or 0
        self.vertical_splits = opt.get('vertical_splits') or 0
        # A value of 0 would divide by zero when rounding sizes; treat it as "off".
        self.force_multiple = opt.get('force_multiple') or 1
        self.paths_GT = None  # not used by this dataset
        self.LQ_env = None  # lmdb environment, opened lazily by _init_lmdb

        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
        self.paths_LQ = self.paths_LQ[self.start_at:]
        assert self.paths_LQ, 'Error: LQ paths are empty.'

    def _init_lmdb(self):
        '''Open the lmdb environment at ``opt['dataroot_LQ']`` read-only.'''
        # Opened on first __getitem__ rather than in __init__ (presumably so each
        # DataLoader worker gets its own handle -- confirm).
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False,
                                readahead=False, meminit=False)

    def _locate(self, index):
        '''Map a dataset index to ``(path_index, split_index)``.

        ``split_index`` is ``None`` when vertical splitting is disabled.
        '''
        if self.vertical_splits > 0:
            return divmod(index, self.vertical_splits)
        return index, None

    def _force_multiple_size(self, width, height):
        '''Return ``(width, height)`` rounded down to multiples of ``force_multiple``.'''
        multiple = self.force_multiple
        if multiple <= 1:
            return width, height
        return width - width % multiple, height - height % multiple

    @staticmethod
    def _tag_path(path, split_index):
        '''Append ``_<split_index>`` to the file stem of ``path``.

        Only the final path component is modified; directory names and the
        extension are left untouched.
        '''
        stem, ext = osp.splitext(path)
        return '%s_%i%s' % (stem, split_index, ext)

    def __getitem__(self, index):
        '''Load sample ``index``.

        Returns:
            dict with ``'LQ'`` (the image converted by ``F.to_tensor``) and
            ``'LQ_path'`` (the source path, with ``_<strip index>`` appended to
            the file stem when vertical splitting is enabled).
        '''
        if self.data_type == 'lmdb' and self.LQ_env is None:
            self._init_lmdb()
        # NOTE(review): LQ_env is opened but never read below; images are always
        # loaded with Image.open(LQ_path). Confirm lmdb input is actually supported.

        path_index, split_index = self._locate(index)
        LQ_path = self.paths_LQ[path_index]

        # The context manager closes the file; to_tensor reads the pixels
        # before the file is closed.
        with Image.open(LQ_path) as source:
            img_LQ = source
            if split_index is not None:
                width, height = img_LQ.size  # PIL size is (width, height)
                strip_width = width // self.vertical_splits
                # Trailing columns beyond vertical_splits * strip_width are dropped.
                img_LQ = F.crop(img_LQ, 0, strip_width * split_index, height, strip_width)

            # Enforce the force_multiple constraint by resizing (not cropping).
            size = img_LQ.size
            target = self._force_multiple_size(*size)
            if target != size:
                img_LQ = img_LQ.resize(target)

            tensor_LQ = F.to_tensor(img_LQ)

        if split_index is not None:
            LQ_path = self._tag_path(LQ_path, split_index)

        return {'LQ': tensor_LQ, 'LQ_path': LQ_path}

    def __len__(self):
        '''Number of samples: one per path, times ``vertical_splits`` when splitting.'''
        count = len(self.paths_LQ)
        return count * self.vertical_splits if self.vertical_splits > 0 else count
|