diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py index 0a3dc0b4..0a96eb20 100644 --- a/codes/data/multiscale_dataset.py +++ b/codes/data/multiscale_dataset.py @@ -10,6 +10,9 @@ import torchvision.transforms.functional as F # Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes. +from data.image_corruptor import ImageCorruptor + + class MultiScaleDataset(data.Dataset): def __init__(self, opt): super(MultiScaleDataset, self).__init__() @@ -20,6 +23,7 @@ class MultiScaleDataset(data.Dataset): self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.scale = self.opt['scale'] self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1]) + self.corruptor = ImageCorruptor(opt) # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # offset from center is chosen on a normal probability curve. @@ -62,13 +66,19 @@ class MultiScaleDataset(data.Dataset): img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA) patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] self.recursively_extract_patches(img_full, patches_hq, 1) + # Image corruption is applied against the full size image for this dataset. + img_corrupted = self.corruptor.corrupt_images([img_full])[0] + patches_hq_corrupted = [cv2.resize(img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] + self.recursively_extract_patches(img_corrupted, patches_hq_corrupted, 1) # BGR to RGB, HWC to CHW, numpy to tensor if patches_hq[0].shape[2] == 3: patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq] + patches_hq_corrupted = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted] patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq] patches_hq = torch.stack(patches_hq, dim=0) - patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq] + patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted] + patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq_corrupted] patches_lq = torch.stack(patches_lq, dim=0) d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path} @@ -118,10 +128,14 @@ def _build_multiscale_patch_index_map(depth, ind, node, leaves): if __name__ == '__main__': opt = { 'name': 'amalgam', - 'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'], + 'paths': ['F:\\4k6k\\datasets\\images\\div2k\\DIV2K_train_HR'], 'num_scales': 4, 'scale': 2, - 'hq_tile_size': 128 + 'hq_tile_size': 128, + 'fixed_corruptions': ['jpeg'], + 'random_corruptions': ['gaussian_blur', 'motion-blur', 'noise-5'], + 'num_corrupts_per_image': 1, + 'corruption_blur_scale': 5 } import torchvision @@ -133,14 +147,15 @@ if __name__ == '__main__': quadrant=2 print(i) o = ds[random.randint(0, len(ds))] - k = 'HQ' - v = o['HQ'] - #for j, img in enumerate(v): - # torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j)) tree_ind = random.randint(0, len(multiscale_tree)) - node = multiscale_tree[tree_ind] - depth = 0 - while node is not None: - torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth)) - depth += 1 - node = node.parent \ No newline at end of file + for k, v in o.items(): + if 'path' in k: + continue + depth = 0 + node = multiscale_tree[tree_ind] + #for j, img in enumerate(v): + # torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j)) + while node is not None: + torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth)) + depth += 1 + node = node.parent \ No newline at end of file