Fix dataset for a val set that includes lq

This commit is contained in:
James Betker 2020-09-24 18:01:07 -06:00
parent ea565b7eaf
commit 1cf73c2cce

View File

@ -232,9 +232,14 @@ class FullImageDataset(data.Dataset):
if self.paths_LQ:
LQ_path = self.get_lq_path(index)
img_lq_full = util.read_img(None, LQ_path, None)
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_lq_full = self.get_square_image(img_lq_full)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
if self.opt['phase'] == 'train':
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_lq_full = self.get_square_image(img_lq_full)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
else:
img_LQ, lq_fullsize_ref = img_lq_full, img_lq_full
lq_mask = np.ones(img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype)
lq_center = torch.tensor([img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':