Update image_folder_dataset.py

This commit is contained in:
James Betker 2021-06-25 11:48:31 -06:00 committed by GitHub
parent a0ef07ddb8
commit 61e7ca39cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -32,6 +32,9 @@ class ImageFolderDataset:
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
self.fetch_alt_image = opt['fetch_alt_image'] # If specified, this dataset will attempt to find a second image
# from the same video source. Search for 'fetch_alt_image' for more info.
self.fetch_alt_tiled_image = opt['fetch_alt_tiled_image'] # If specified, this dataset will attempt to find anoter tile from the same source image
# Search for 'fetch_alt_tiled_image' for more info.
assert not (self.fetch_alt_image and self.fetch_alt_tiled_image) # These are mutually exclusive.
self.skip_lq = opt_get(opt, ['skip_lq'], False)
self.disable_flip = opt_get(opt, ['disable_flip'], False)
self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
@ -188,6 +191,25 @@ class ImageFolderDataset:
if not self.skip_lq:
for_lq.append(hs[0])
out_dict['alt_hq'] = alt_hq
if self.fetch_alt_tiled_image:
# This assumes the output format generated by the tiled image generation scripts included with DLAS. Specifically,
# all image read by this dataset are assumed to be in subfolders with other tiles from the same source image. When
# this option is set, another random image from the same folder is selected and returned as the alt image.
sel_path = self.image_paths[item]
other_images = random.shuffle(os.listdir(sel_path))
# Assume that the directory contains at least <image>, <ref.jpg>, <centers.pt>
if len(other_images) <= 3:
alt_hq = hq # This is a fallback in case an alt image can't be found.
else:
for oi in other_images:
if oi == sel_path or 'ref.' in oi or 'centers.pt' in oi:
continue
alt_hq = util.read_img(None, oi, rgb=True)
alt_hs = self.resize_hq([alt_hq])
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
out_dict['has_alt'] = True
out_dict['alt_hq'] = alt_hq
if not self.skip_lq:
lqs, ent = self.synthesize_lq(for_lq)
@ -244,4 +266,4 @@ if __name__ == '__main__':
lq = d['lq']
#torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
if i >= 200000:
break
break