diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py index cf2fe8a1..84893699 100644 --- a/codes/data/full_image_dataset.py +++ b/codes/data/full_image_dataset.py @@ -232,9 +232,14 @@ class FullImageDataset(data.Dataset): if self.paths_LQ: LQ_path = self.get_lq_path(index) img_lq_full = util.read_img(None, LQ_path, None) - img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] - img_lq_full = self.get_square_image(img_lq_full) - img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True) + if self.opt['phase'] == 'train': + img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] + img_lq_full = self.get_square_image(img_lq_full) + img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True) + else: + img_LQ, lq_fullsize_ref = img_lq_full, img_lq_full + lq_mask = np.ones(img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype) + lq_center = torch.tensor([img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long) else: # down-sampling on-the-fly # randomly scale during training if self.opt['phase'] == 'train':