diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 1ee85dab..1c21c920 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -10,6 +10,7 @@ import os from data import util # Builds a dataset created from a simple folder containing a list of training/test/validation images. from data.image_corruptor import ImageCorruptor +from data.image_label_parser import VsNetImageLabeler class ImageFolderDataset: @@ -28,21 +29,29 @@ class ImageFolderDataset: else: self.weights = opt['weights'] - # Just scan the given directory for images of standard types. - supported_types = ['jpg', 'jpeg', 'png', 'gif'] - self.image_paths = [] - for path, weight in zip(self.paths, self.weights): - cache_path = os.path.join(path, 'cache.pth') - if os.path.exists(cache_path): - imgs = torch.load(cache_path) - else: - print("Building image folder cache, this can take some time for large datasets..") - imgs = [] - for ext in supported_types: - imgs.extend(glob.glob(os.path.join(path, "*." + ext))) - torch.save(imgs, cache_path) - for w in range(weight): - self.image_paths.extend(imgs) + if 'labeler' in opt.keys(): + if opt['labeler']['type'] == 'patch_labels': + self.labeler = VsNetImageLabeler(opt['labeler']['label_file']) + assert len(self.paths) == 1 # Only a single base-path is supported for labeled images. + self.image_paths = self.labeler.get_labeled_paths(self.paths[0]) + else: + self.labeler = None + + # Just scan the given directory for images of standard types. + supported_types = ['jpg', 'jpeg', 'png', 'gif'] + self.image_paths = [] + for path, weight in zip(self.paths, self.weights): + cache_path = os.path.join(path, 'cache.pth') + if os.path.exists(cache_path): + imgs = torch.load(cache_path) + else: + print("Building image folder cache, this can take some time for large datasets..") + imgs = [] + for ext in supported_types: + imgs.extend(glob.glob(os.path.join(path, "*." + ext))) + torch.save(imgs, cache_path) + for w in range(weight): + self.image_paths.extend(imgs) self.len = len(self.image_paths) def get_paths(self): @@ -74,6 +83,7 @@ class ImageFolderDataset: h, w, _ = hs[0].shape ls = [] if self.corrupt_before_downsize: + hs = [h.copy() for h in hs] hs = self.corruptor.corrupt_images(hs) for hq in hs: ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) @@ -87,6 +97,9 @@ class ImageFolderDataset: def __getitem__(self, item): hq = util.read_img(None, self.image_paths[item], rgb=True) + if self.labeler: + assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet. + dim = hq.shape[0] hs = self.resize_hq([hq]) ls = self.synthesize_lq(hs) @@ -95,13 +108,19 @@ class ImageFolderDataset: hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float() lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float() - return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]} - + out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]} + if self.labeler: + base_file = self.image_paths[item].replace(self.paths[0], "") + assert dim % hq.shape[1] == 0 + lbls, lbl_masks = self.labeler.get_labels_as_tensor(hq, base_file, dim // hq.shape[1]) + out_dict['labels'] = lbls + out_dict['labels_mask'] = lbl_masks + return out_dict if __name__ == '__main__': opt = { 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\random_100_1024px'], + 'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'], 'weights': [1], 'target_size': 128, 'force_multiple': 32, @@ -110,20 +129,19 @@ if __name__ == '__main__': 'random_corruptions': ['noise-5', 'none'], 'num_corrupts_per_image': 1, 'corrupt_before_downsize': True, + 'labeler': { + 'type': 'patch_labels', + 'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories.json' + } } ds = ImageFolderDataset(opt) import os os.makedirs("debug", exist_ok=True) for i in range(0, len(ds)): - o = ds[random.randint(0, len(ds))] - #for k, v in o.items(): - k = 'lq' - v = o[k] - #if 'LQ' in k and 'path' not in k and 'center' not in k: - #if 'full' in k: - #masked = v[:3, :, :] * v[3] - #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) - #v = v[:3, :, :] + o = ds[random.randint(0, len(ds)-1)] + hq = o['hq'] + masked = (o['labels_mask'] * .5 + .5) * hq import torchvision - torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) \ No newline at end of file + torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,)) + torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,)) \ No newline at end of file diff --git a/codes/data/image_label_parser.py b/codes/data/image_label_parser.py new file mode 100644 index 00000000..241746c6 --- /dev/null +++ b/codes/data/image_label_parser.py @@ -0,0 +1,58 @@ +import os + +import orjson as json +# Given a JSON file produced by the VS.net image labeler utility, produces a dict where the keys are image file names +# and the values are a list of object with the following properties: +# [patch_top, patch_left, patch_height, patch_width, label] +import torch + + +class VsNetImageLabeler: + def __init__(self, label_file): + with open(label_file, "r") as read_file: + # Format of JSON file: + # "" { + # "label": "