import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image
from io import BytesIO
import torchvision.transforms.functional as F


class LQGTDataset(data.Dataset):
    """
    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
    If only GT images are provided, generate LQ images on-the-fly.

    An optional PIX image (``opt['dataroot_PIX']``) is read for the same index as GT and is
    given the same spatial processing as GT (modcrop, on-the-fly rescale, crop/resize,
    flip/rotate). When no PIX root is configured, the GT image itself is used as PIX.

    Each item is a dict with CHW tensors ``'LQ'``, ``'GT'``, ``'PIX'`` and the source paths
    ``'LQ_path'`` / ``'GT_path'``. ``LQ_path`` falls back to ``GT_path`` when LQ is synthesized.
    ``'LQ'`` is always quantized to 8 bits (and optionally JPEG-compressed) before
    ``F.to_tensor`` turns it back into a float tensor.
    """

    def __init__(self, opt):
        super(LQGTDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.paths_PIX, self.sizes_PIX = None, None
        self.LQ_env, self.GT_env, self.PIX_env = None, None, None  # environments for lmdbs
        # Per-LQ-source bookkeeping, parallel to self.paths_LQ. self.sizes_LQ / self.LQ_env are
        # kept for compatibility, but each only describes a single source.
        self._lq_sizes = []
        self._lq_envs = []

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        if 'dataroot_LQ' in opt.keys():
            self.paths_LQ = []
            # Multiple LQ data sources can be given, in case there are multiple ways of corrupting
            # a source image and we want the model to learn them all.
            for dr_lq in self._lq_roots():
                lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
                self.paths_LQ.append(lq_path)
                self._lq_sizes.append(self.sizes_LQ)
        self.doCrop = opt['doCrop']
        if 'dataroot_PIX' in opt.keys():
            self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type,
                                                                  opt['dataroot_PIX'])

        assert self.paths_GT, 'Error: GT path is empty.'
        if self.paths_LQ and self.paths_GT:
            # Every LQ source is indexed with the GT index, so each one must line up with GT.
            for lq_paths in self.paths_LQ:
                assert len(lq_paths) == len(
                    self.paths_GT
                ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
                    len(lq_paths), len(self.paths_GT))
        self.random_scale_list = [1]

    def _lq_roots(self):
        """Return the configured LQ data roots as a list, skipping unset (None) entries."""
        roots = self.opt.get('dataroot_LQ')
        if roots is None:
            return []
        if not isinstance(roots, list):
            roots = [roots]
        return [r for r in roots if r is not None]

    def _choose_lq_source(self):
        """Pick the index of one LQ source uniformly at random."""
        return random.randint(0, len(self.paths_LQ) - 1)

    def get_lq_path(self, i):
        """Return the path of LQ image ``i`` from a uniformly chosen LQ source."""
        return self.paths_LQ[self._choose_lq_source()][i]

    @staticmethod
    def _open_lmdb(path):
        """Open an LMDB environment read-only with locking, readahead and meminit disabled."""
        return lmdb.open(path, readonly=True, lock=False, readahead=False, meminit=False)

    def _init_lmdb(self):
        """Open the GT, LQ (one environment per LQ source) and optional PIX LMDB environments."""
        # https://github.com/chainer/chainermn/issues/129
        # Invoked lazily from __getitem__ rather than __init__ (presumably so each DataLoader
        # worker opens its own handles; see the link above).
        self.GT_env = self._open_lmdb(self.opt['dataroot_GT'])
        self._lq_envs = [self._open_lmdb(dr_lq) for dr_lq in self._lq_roots()]
        self.LQ_env = self._lq_envs[0] if self._lq_envs else None
        if 'dataroot_PIX' in self.opt.keys():
            self.PIX_env = self._open_lmdb(self.opt['dataroot_PIX'])

    def _lmdb_resolution(self, sizes, index):
        """Parse the underscore-separated LMDB size entry for ``index`` into ints (None if not LMDB)."""
        if self.data_type != 'lmdb':
            return None
        return [int(s) for s in sizes[index].split('_')]

    @staticmethod
    def _mod(n, random_scale, scale, thres):
        """Scale ``n``, round it down to a multiple of ``scale``, and clamp it to at least ``thres``."""
        rlt = int(n * random_scale)
        rlt = (rlt // scale) * scale
        return thres if rlt < thres else rlt

    @staticmethod
    def _rescale_to(img, width, height):
        """Bilinearly resize to ``height`` x ``width``; a 2-D (gray) result becomes 3-channel BGR."""
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        return img

    @staticmethod
    def _resize_square(img, size):
        """Bilinearly resize to ``size`` x ``size``, restoring the channel axis cv2 drops for 1-channel input."""
        out = cv2.resize(img, (size, size), interpolation=cv2.INTER_LINEAR)
        if out.ndim == 2:
            out = np.expand_dims(out, axis=2)
        return out

    @staticmethod
    def _to_uint8_pil(img):
        """Quantize a float image (assumed to lie in [0, 1]) to an 8-bit PIL image.

        Values are rounded and clipped first: a bare ``astype(np.uint8)`` does not saturate,
        so values slightly above 1.0 or below 0.0 (e.g. resampling overshoot) would wrap around.
        """
        img = np.clip(np.round(img * 255), 0, 255).astype(np.uint8)
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]  # PIL cannot build an image from an (H, W, 1) array
        return Image.fromarray(img)

    @staticmethod
    def _jpeg_compress(img):
        """Round-trip a PIL image through an in-memory JPEG with a random quality in [15, 100)."""
        qf = random.randrange(15, 100)
        corruption_buffer = BytesIO()
        img.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
        corruption_buffer.seek(0)
        compressed = Image.open(corruption_buffer)
        compressed.load()  # decode now, while the buffer is certainly still alive
        return compressed

    @staticmethod
    def _hwc_to_tensor(img):
        """Convert an HWC numpy image to a contiguous CHW float tensor."""
        return torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()

    def __getitem__(self, index):
        if self.data_type == 'lmdb' and self.GT_env is None:
            self._init_lmdb()
        LQ_path = None
        scale = self.opt['scale']
        GT_size = self.opt['target_size']
        is_train = self.opt['phase'] == 'train'

        # get GT image
        GT_path = self.paths_GT[index]
        resolution = self._lmdb_resolution(self.sizes_GT, index)
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if not is_train:  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get the pix image; it shares GT's geometry, so it receives the same modcrop as GT.
        pix_is_gt = self.paths_PIX is None
        if pix_is_gt:
            img_PIX = img_GT
        else:
            # PIX is read with GT's LMDB resolution, i.e. it is expected to match GT's size.
            img_PIX = util.read_img(self.PIX_env, self.paths_PIX[index], resolution)
            if not is_train:
                img_PIX = util.modcrop(img_PIX, scale)
            if self.opt['color']:  # change color space if necessary
                img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0]

        # get LQ image
        already_blurred = False
        if self.paths_LQ:
            which_lq = self._choose_lq_source()
            LQ_path = self.paths_LQ[which_lq][index]
            # Use the chosen source's own LMDB metadata and environment.
            resolution = self._lmdb_resolution(self._lq_sizes[which_lq], index)
            lq_env = self._lq_envs[which_lq] if self._lq_envs else None
            img_LQ = util.read_img(lq_env, LQ_path, resolution)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if is_train:
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = img_GT.shape
                H_s = self._mod(H_s, random_scale, scale, GT_size)
                W_s = self._mod(W_s, random_scale, scale, GT_size)
                img_GT = self._rescale_to(img_GT, W_s, H_s)
                # PIX must follow GT: the training crop below reuses GT's coordinates on PIX.
                img_PIX = img_GT if pix_is_gt else self._rescale_to(img_PIX, W_s, H_s)

            if self.opt['use_blurring']:
                # blur_sig is drawn from {0, 1, 2, 3}; 0 makes OpenCV derive sigma from the kernel size.
                blur_sig = int(random.randrange(0, 4))
                hqxform = cv2.GaussianBlur(img_GT, (3, 3), blur_sig)
                already_blurred = True
            else:
                hqxform = img_GT
            # using matlab imresize
            img_LQ = util.imresize_np(hqxform, 1 / scale, True)
            if img_LQ.ndim == 2:
                img_LQ = np.expand_dims(img_LQ, axis=2)

        if is_train:
            H, W, _ = img_GT.shape
            assert H >= GT_size and W >= GT_size
            H, W, _ = img_LQ.shape
            LQ_size = GT_size // scale
            if self.doCrop:
                # randomly crop; GT and PIX use the LQ offsets scaled up by `scale`.
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
                img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
            else:
                img_LQ = self._resize_square(img_LQ, LQ_size)
                img_GT = self._resize_square(img_GT, GT_size)
                img_PIX = self._resize_square(img_PIX, GT_size)
            # augmentation - flip, rotate
            img_LQ, img_GT, img_PIX = util.augment([img_LQ, img_GT, img_PIX], self.opt['use_flip'],
                                                   self.opt['use_rot'])
            if self.opt['use_blurring'] and not already_blurred:
                blur_sig = int(random.randrange(0, 4))
                img_LQ = cv2.GaussianBlur(img_LQ, (3, 3), blur_sig)

        if self.opt['color']:  # change color space if necessary
            # Channel count is taken from img_LQ itself so this also works outside training.
            lq_channels = img_LQ.shape[2] if img_LQ.ndim == 3 else 1
            img_LQ = util.channel_convert(lq_channels, self.opt['color'], [img_LQ])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
            img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)

        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
        img_LQ = self._to_uint8_pil(img_LQ)
        if self.opt['use_compression_artifacts']:
            img_LQ = self._jpeg_compress(img_LQ)

        img_GT = self._hwc_to_tensor(img_GT)
        img_PIX = self._hwc_to_tensor(img_PIX)
        img_LQ = F.to_tensor(img_LQ)

        if LQ_path is None:
            LQ_path = GT_path
        return {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        return len(self.paths_GT)