Allow more LQ than GT images in corrupt mode

This commit is contained in:
James Betker 2020-05-14 20:46:20 -06:00
parent 8a514b9645
commit d72e154442

View File

@ -49,7 +49,7 @@ class DownsampleDataset(data.Dataset):
GT_size = self.opt['target_size'] * scale GT_size = self.opt['target_size'] * scale
# get GT image # get GT image
GT_path = self.paths_GT[index] GT_path = self.paths_GT[index % len(self.paths_GT)]
resolution = [int(s) for s in self.sizes_GT[index].split('_') resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None ] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution) img_GT = util.read_img(self.GT_env, GT_path, resolution)
@ -59,7 +59,6 @@ class DownsampleDataset(data.Dataset):
img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
# get LQ image # get LQ image
lqind = index % len(self.paths_LQ)
LQ_path = self.paths_LQ[index % len(self.paths_LQ)] LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
resolution = [int(s) for s in self.sizes_LQ[index].split('_') resolution = [int(s) for s in self.sizes_LQ[index].split('_')
] if self.data_type == 'lmdb' else None ] if self.data_type == 'lmdb' else None
@ -116,4 +115,4 @@ class DownsampleDataset(data.Dataset):
return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path} return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}
def __len__(self): def __len__(self):
return len(self.paths_GT) return max(len(self.paths_GT), len(self.paths_LQ))