diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index db51c383..fddced48 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -32,6 +32,9 @@ class ImageFolderDataset: self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False self.fetch_alt_image = opt['fetch_alt_image'] # If specified, this dataset will attempt to find a second image # from the same video source. Search for 'fetch_alt_image' for more info. + self.fetch_alt_tiled_image = opt['fetch_alt_tiled_image'] # If specified, this dataset will attempt to find anoter tile from the same source image + # Search for 'fetch_alt_tiled_image' for more info. + assert not (self.fetch_alt_image and self.fetch_alt_tiled_image) # These are mutually exclusive. self.skip_lq = opt_get(opt, ['skip_lq'], False) self.disable_flip = opt_get(opt, ['disable_flip'], False) self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False) @@ -188,6 +191,25 @@ class ImageFolderDataset: if not self.skip_lq: for_lq.append(hs[0]) out_dict['alt_hq'] = alt_hq + + if self.fetch_alt_tiled_image: + # This assumes the output format generated by the tiled image generation scripts included with DLAS. Specifically, + # all image read by this dataset are assumed to be in subfolders with other tiles from the same source image. When + # this option is set, another random image from the same folder is selected and returned as the alt image. + sel_path = self.image_paths[item] + other_images = random.shuffle(os.listdir(sel_path)) + # Assume that the directory contains at least , , + if len(other_images) <= 3: + alt_hq = hq # This is a fallback in case an alt image can't be found. + else: + for oi in other_images: + if oi == sel_path or 'ref.' in oi or 'centers.pt' in oi: + continue + alt_hq = util.read_img(None, oi, rgb=True) + alt_hs = self.resize_hq([alt_hq]) + alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float() + out_dict['has_alt'] = True + out_dict['alt_hq'] = alt_hq if not self.skip_lq: lqs, ent = self.synthesize_lq(for_lq) @@ -244,4 +266,4 @@ if __name__ == '__main__': lq = d['lq'] #torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png') if i >= 200000: - break \ No newline at end of file + break