diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index fe334eab..7637d864 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -12,6 +12,7 @@ from data import util
 # Builds a dataset created from a simple folder containing a list of training/test/validation images.
 from data.image_corruptor import ImageCorruptor
 from data.image_label_parser import VsNetImageLabeler
+from utils.util import opt_get
 
 
 class ImageFolderDataset:
@@ -25,8 +26,8 @@ class ImageFolderDataset:
         self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
         self.fetch_alt_image = opt['fetch_alt_image']  # If specified, this dataset will attempt to find a second image
                                                        # from the same video source. Search for 'fetch_alt_image' for more info.
-        self.skip_lq = opt['skip_lq']
-        self.disable_flip = opt['disable_flip']
+        self.skip_lq = opt_get(opt, ['skip_lq'], False)
+        self.disable_flip = opt_get(opt, ['disable_flip'], False)
         assert (self.target_hq_size // self.scale) % self.multiple == 0  # If we dont throw here, we get some really obscure errors.
         if not isinstance(self.paths, list):
             self.paths = [self.paths]
@@ -119,16 +120,12 @@ class ImageFolderDataset:
         hs = self.resize_hq([hq])
 
         if not self.skip_lq:
-            ls = self.synthesize_lq(hs)
+            for_lq = [hs[0]]
 
         # Convert to torch tensor
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
-        if not self.skip_lq:
-            lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
 
         out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
-        if not self.skip_lq:
-            out_dict['lq'] = lq
 
         if self.fetch_alt_image:
             # This works by assuming a specific filename structure as would produced by ffmpeg. ex:
@@ -151,22 +148,28 @@ class ImageFolderDataset:
                 # the file rather than searching the path list. Let the exception handler below do its work.
                 next_img = self.image_paths[item].replace(str(imnumber), str(imnumber+1))
                 alt_hq = util.read_img(None, next_img, rgb=True)
-                alt_hq = self.resize_hq([alt_hq])
-                alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hq[0], (2, 0, 1)))).float()
+                alt_hs = self.resize_hq([alt_hq])
+                alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
                 if not self.skip_lq:
-                    alt_lq = self.synthesize_lq(alt_hq)
-                    alt_lq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq[0], (2, 0, 1)))).float()
+                    for_lq.append(alt_hs[0])
             except:
                 alt_hq = hq
                 if not self.skip_lq:
-                    alt_lq = lq
+                    for_lq.append(hs[0])
         else:
             alt_hq = hq
             if not self.skip_lq:
-                alt_lq = lq
+                for_lq.append(hs[0])
         out_dict['alt_hq'] = alt_hq
-        if not self.skip_lq:
-            out_dict['alt_lq'] = alt_lq
+
+        if not self.skip_lq:
+            lqs = self.synthesize_lq(for_lq)
+            ls = lqs[0]
+            out_dict['lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(ls, (2, 0, 1)))).float()
+            if len(lqs) > 1:
+                alt_lq = lqs[1]
+                out_dict['alt_lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq, (2, 0, 1)))).float()
+
 
         if self.labeler:
             base_file = self.image_paths[item].replace(self.paths[0], "")
@@ -190,7 +193,7 @@ if __name__ == '__main__':
         'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
         'random_corruptions': ['noise-5', 'none'],
         'num_corrupts_per_image': 1,
-        'corrupt_before_downsize': True,
+        'corrupt_before_downsize': False,
         'fetch_alt_image': True,
         #'labeler': {
         #    'type': 'patch_labels',
@@ -203,11 +206,11 @@ if __name__ == '__main__':
     os.makedirs("debug", exist_ok=True)
     for i in range(0, len(ds)):
        o = ds[random.randint(0, len(ds)-1)]
-        hq = o['hq']
+        hq = o['lq']
         #masked = (o['labels_mask'] * .5 + .5) * hq
         import torchvision
-        torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
-        torchvision.utils.save_image(o['alt_hq'].unsqueeze(0), "debug/%i_hq_alt.png" % (i,))
+        torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_lq.png" % (i,))
+        torchvision.utils.save_image(o['alt_lq'].unsqueeze(0), "debug/%i_lq_alt.png" % (i,))
         #if len(o['labels'].unique()) > 1:
         #    randlbl = np.random.choice(o['labels'].unique()[1:])
         #    moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
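With this change, `skip_lq` and `disable_flip` are read through `opt_get`, so they default to `False` rather than raising `KeyError` when the options dict omits them, and LQ synthesis is deferred so the primary and alternate HQ frames go through `synthesize_lq` in one call. Below is a minimal sketch of the lookup behavior the patch assumes from `utils.util.opt_get`; it is an illustrative stand-in, not the repository's actual implementation, which may differ in detail.

```python
# Minimal sketch of the nested-key lookup assumed from utils.util.opt_get.
# Illustrative only -- the real helper lives in utils/util.py.
def opt_get(opt, keys, default=None):
    # Walk the list of keys through nested dicts, falling back to `default`
    # as soon as any level is missing.
    cur = opt
    for k in keys:
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur


# Example config without the optional keys: both lookups fall back to False.
opt = {'paths': ['/tmp/images'], 'scale': 2}
assert opt_get(opt, ['skip_lq'], False) is False
assert opt_get(opt, ['disable_flip'], False) is False
```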