|
|
|
@ -19,12 +19,14 @@ class LQGTDataset(data.Dataset):
|
|
|
|
|
self.data_type = self.opt['data_type']
|
|
|
|
|
self.paths_LQ, self.paths_GT = None, None
|
|
|
|
|
self.sizes_LQ, self.sizes_GT = None, None
|
|
|
|
|
self.LQ_env, self.GT_env = None, None # environments for lmdb
|
|
|
|
|
self.paths_PIX, self.sizes_PIX = None, None
|
|
|
|
|
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdb
|
|
|
|
|
|
|
|
|
|
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
|
|
|
|
self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
|
|
|
|
if 'dataroot_PIX' in opt:
|
|
|
|
|
if 'dataroot_PIX' in opt.keys():
|
|
|
|
|
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
|
|
|
|
|
|
|
|
|
|
assert self.paths_GT, 'Error: GT path is empty.'
|
|
|
|
|
if self.paths_LQ and self.paths_GT:
|
|
|
|
|
assert len(self.paths_LQ) == len(
|
|
|
|
@ -39,7 +41,7 @@ class LQGTDataset(data.Dataset):
|
|
|
|
|
meminit=False)
|
|
|
|
|
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
|
|
|
|
meminit=False)
|
|
|
|
|
if 'dataroot_PIX' in self.opt:
|
|
|
|
|
if 'dataroot_PIX' in self.opt.keys():
|
|
|
|
|
self.PIX_env = lmdb.open(self.opt['dataroot_PIX'], readonly=True, lock=False, readahead=False,
|
|
|
|
|
meminit=False)
|
|
|
|
|
|
|
|
|
@ -61,12 +63,13 @@ class LQGTDataset(data.Dataset):
|
|
|
|
|
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
|
|
|
|
|
|
|
|
|
|
# get the pix image
|
|
|
|
|
if self.PIX_path is not None:
|
|
|
|
|
PIX_path = self.PIX_path[index]
|
|
|
|
|
if self.paths_PIX is not None:
|
|
|
|
|
PIX_path = self.paths_PIX[index]
|
|
|
|
|
img_PIX = util.read_img(self.PIX_env, PIX_path, resolution)
|
|
|
|
|
if self.opt['color']: # change color space if necessary
|
|
|
|
|
img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
img_PIX = img_GT
|
|
|
|
|
|
|
|
|
|
# get LQ image
|
|
|
|
|
if self.paths_LQ:
|
|
|
|
@ -98,14 +101,8 @@ class LQGTDataset(data.Dataset):
|
|
|
|
|
img_LQ = np.expand_dims(img_LQ, axis=2)
|
|
|
|
|
|
|
|
|
|
if self.opt['phase'] == 'train':
|
|
|
|
|
# if the image size is too small
|
|
|
|
|
H, W, _ = img_GT.shape
|
|
|
|
|
if H < GT_size or W < GT_size:
|
|
|
|
|
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
|
|
|
|
# using matlab imresize
|
|
|
|
|
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
|
|
|
|
if img_LQ.ndim == 2:
|
|
|
|
|
img_LQ = np.expand_dims(img_LQ, axis=2)
|
|
|
|
|
assert H >= GT_size and W >= GT_size
|
|
|
|
|
|
|
|
|
|
H, W, C = img_LQ.shape
|
|
|
|
|
LQ_size = GT_size // scale
|
|
|
|
@ -116,9 +113,10 @@ class LQGTDataset(data.Dataset):
|
|
|
|
|
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
|
|
|
|
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
|
|
|
|
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
|
|
|
|
img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
|
|
|
|
|
|
|
|
|
# augmentation - flip, rotate
|
|
|
|
|
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
|
|
|
|
|
img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
|
|
|
|
|
self.opt['use_rot'])
|
|
|
|
|
|
|
|
|
|
if self.opt['color']: # change color space if necessary
|
|
|
|
@ -129,12 +127,14 @@ class LQGTDataset(data.Dataset):
|
|
|
|
|
if img_GT.shape[2] == 3:
|
|
|
|
|
img_GT = img_GT[:, :, [2, 1, 0]]
|
|
|
|
|
img_LQ = img_LQ[:, :, [2, 1, 0]]
|
|
|
|
|
img_PIX = img_PIX[:, :, [2, 1, 0]]
|
|
|
|
|
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
|
|
|
|
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
|
|
|
|
|
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
|
|
|
|
|
|
|
|
|
if LQ_path is None:
|
|
|
|
|
LQ_path = GT_path
|
|
|
|
|
return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
|
|
|
|
return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return len(self.paths_GT)
|
|
|
|
|