From 5c1832e124e8827fc971f36bab0b0e52636061da Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 6 May 2020 17:24:17 -0600 Subject: [PATCH] Add support for multiple LQ paths I want to be able to specify many different transformations onto the target data; the model should handle them all. Do this by allowing multiple LQ paths to be selected and the dataset class selects one at random. --- codes/data/LQGT_dataset.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 31a292bc..2d54cda7 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -13,6 +13,10 @@ class LQGTDataset(data.Dataset): If only GT images are provided, generate LQ images on-the-fly. """ + def get_lq_path(self, i): + which_lq = random.randint(0, len(self.paths_LQ)) + return self.paths_LQ[which_lq][i] + def __init__(self, opt): super(LQGTDataset, self).__init__() self.opt = opt @@ -23,17 +27,26 @@ class LQGTDataset(data.Dataset): self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) - self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + self.paths_LQ = [] + if isinstance(opt['dataroot_LQ'], list): + # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and + # we want the model to learn them all. + for dr_lq in opt['dataroot_LQ']: + lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq) + self.paths_LQ.append(lq_path) + else: + lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + self.paths_LQ.append(lq_path) self.doCrop = opt['doCrop'] if 'dataroot_PIX' in opt.keys(): self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX']) assert self.paths_GT, 'Error: GT path is empty.' if self.paths_LQ and self.paths_GT: - assert len(self.paths_LQ) == len( + assert len(self.paths_LQ[0]) == len( self.paths_GT ), 'GT and LQ datasets have different number of images - {}, {}.'.format( - len(self.paths_LQ), len(self.paths_GT)) + len(self.paths_LQ[0]), len(self.paths_GT)) self.random_scale_list = [1] def _init_lmdb(self): @@ -74,7 +87,7 @@ class LQGTDataset(data.Dataset): # get LQ image if self.paths_LQ: - LQ_path = self.paths_LQ[index] + LQ_path = self.get_lq_path(index) resolution = [int(s) for s in self.sizes_LQ[index].split('_') ] if self.data_type == 'lmdb' else None img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)