diff --git a/codes/data/Downsample_dataset.py b/codes/data/Downsample_dataset.py index dc73e73d..5e950d05 100644 --- a/codes/data/Downsample_dataset.py +++ b/codes/data/Downsample_dataset.py @@ -49,7 +49,7 @@ class DownsampleDataset(data.Dataset): GT_size = self.opt['target_size'] * scale # get GT image - GT_path = self.paths_GT[index] + GT_path = self.paths_GT[index % len(self.paths_GT)] resolution = [int(s) for s in self.sizes_GT[index].split('_') ] if self.data_type == 'lmdb' else None img_GT = util.read_img(self.GT_env, GT_path, resolution) @@ -59,7 +59,6 @@ class DownsampleDataset(data.Dataset): img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] # get LQ image - lqind = index % len(self.paths_LQ) LQ_path = self.paths_LQ[index % len(self.paths_LQ)] resolution = [int(s) for s in self.sizes_LQ[index].split('_') ] if self.data_type == 'lmdb' else None @@ -116,4 +115,4 @@ class DownsampleDataset(data.Dataset): return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path} def __len__(self): - return len(self.paths_GT) + return max(len(self.paths_GT), len(self.paths_LQ))