diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 83bc6a54..99fed582 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -28,6 +28,7 @@ class LQGTDataset(data.Dataset): self.sizes_LQ, self.sizes_GT = None, None self.paths_PIX, self.sizes_PIX = None, None self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs + self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1 self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) if 'dataroot_LQ' in opt.keys(): @@ -126,6 +127,16 @@ class LQGTDataset(data.Dataset): if img_LQ.ndim == 2: img_LQ = np.expand_dims(img_LQ, axis=2) + # Enforce force_resize constraints. + h, w, _ = img_LQ.shape + if h % self.force_multiple != 0 or w % self.force_multiple != 0: + h, w = (w - w % self.force_multiple), (h - h % self.force_multiple) + img_LQ = cv2.resize(img_LQ, (h, w)) + h *= scale + w *= scale + img_GT = cv2.resize(img_GT, (h, w)) + img_PIX = cv2.resize(img_LQ, (h, w)) + if self.opt['phase'] == 'train': H, W, _ = img_GT.shape assert H >= GT_size and W >= GT_size