diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 587456a6..d1fa1f50 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -90,7 +90,7 @@ class LQGTDataset(data.Dataset): # get the pix image if self.paths_PIX is not None: - PIX_path = self.paths_PIX[index] + PIX_path = self.paths_PIX[index % len(self.paths_PIX)] img_PIX = util.read_img(self.PIX_env, PIX_path, resolution) if self.opt['color']: # change color space if necessary img_PIX = util.channel_convert(img_PIX.shape[2], self.opt['color'], [img_PIX])[0] @@ -123,7 +123,10 @@ class LQGTDataset(data.Dataset): H, W, _ = img_GT.shape # using matlab imresize - img_LQ = util.imresize_np(img_GT, 1 / scale, True) + if scale == 1: + img_LQ = img_GT + else: + img_LQ = util.imresize_np(img_GT, 1 / scale, True) if img_LQ.ndim == 2: img_LQ = np.expand_dims(img_LQ, axis=2) @@ -160,11 +163,14 @@ class LQGTDataset(data.Dataset): img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] img_PIX = img_PIX[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] else: - img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) + if img_LQ.shape[0] != LQ_size: + img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) if img_GAN is not None: img_GAN = cv2.resize(img_GAN, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + if img_GT.shape[0] != GT_size: + img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + if img_PIX.shape[0] != GT_size: + img_PIX = cv2.resize(img_PIX, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) if 'doResizeLoss' in self.opt.keys() and self.opt['doResizeLoss']: r = random.randrange(0, 10) @@ -224,7 +230,7 @@ class LQGTDataset(data.Dataset): if LQ_path is None: LQ_path = GT_path - d = {'LQ': img_LQ, 'GT': img_GT, 'PIX': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path} + d = {'LQ': img_LQ, 'GT': img_GT, 'ref': img_PIX, 'LQ_path': LQ_path, 'GT_path': GT_path} if img_GAN is not None: d['GAN'] = img_GAN return d