Allow image corruption in multiscale dataset
This commit is contained in:
parent
668cafa798
commit
ffad0e0422
|
@ -10,6 +10,9 @@ import torchvision.transforms.functional as F
|
|||
|
||||
|
||||
# Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes.
|
||||
from data.image_corruptor import ImageCorruptor
|
||||
|
||||
|
||||
class MultiScaleDataset(data.Dataset):
|
||||
def __init__(self, opt):
|
||||
super(MultiScaleDataset, self).__init__()
|
||||
|
@ -20,6 +23,7 @@ class MultiScaleDataset(data.Dataset):
|
|||
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
|
||||
self.scale = self.opt['scale']
|
||||
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
|
||||
self.corruptor = ImageCorruptor(opt)
|
||||
|
||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||
# offset from center is chosen on a normal probability curve.
|
||||
|
@ -62,13 +66,19 @@ class MultiScaleDataset(data.Dataset):
|
|||
img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
|
||||
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||
self.recursively_extract_patches(img_full, patches_hq, 1)
|
||||
# Image corruption is applied against the full size image for this dataset.
|
||||
img_corrupted = self.corruptor.corrupt_images([img_full])[0]
|
||||
patches_hq_corrupted = [cv2.resize(img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||
self.recursively_extract_patches(img_corrupted, patches_hq_corrupted, 1)
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if patches_hq[0].shape[2] == 3:
|
||||
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
|
||||
patches_hq_corrupted = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted]
|
||||
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
|
||||
patches_hq = torch.stack(patches_hq, dim=0)
|
||||
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq]
|
||||
patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
|
||||
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq_corrupted]
|
||||
patches_lq = torch.stack(patches_lq, dim=0)
|
||||
|
||||
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
|
||||
|
@ -118,10 +128,14 @@ def _build_multiscale_patch_index_map(depth, ind, node, leaves):
|
|||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'name': 'amalgam',
|
||||
'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'],
|
||||
'paths': ['F:\\4k6k\\datasets\\images\\div2k\\DIV2K_train_HR'],
|
||||
'num_scales': 4,
|
||||
'scale': 2,
|
||||
'hq_tile_size': 128
|
||||
'hq_tile_size': 128,
|
||||
'fixed_corruptions': ['jpeg'],
|
||||
'random_corruptions': ['gaussian_blur', 'motion-blur', 'noise-5'],
|
||||
'num_corrupts_per_image': 1,
|
||||
'corruption_blur_scale': 5
|
||||
}
|
||||
|
||||
import torchvision
|
||||
|
@ -133,14 +147,15 @@ if __name__ == '__main__':
|
|||
quadrant=2
|
||||
print(i)
|
||||
o = ds[random.randint(0, len(ds))]
|
||||
k = 'HQ'
|
||||
v = o['HQ']
|
||||
#for j, img in enumerate(v):
|
||||
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
||||
tree_ind = random.randint(0, len(multiscale_tree))
|
||||
node = multiscale_tree[tree_ind]
|
||||
depth = 0
|
||||
while node is not None:
|
||||
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
|
||||
depth += 1
|
||||
node = node.parent
|
||||
for k, v in o.items():
|
||||
if 'path' in k:
|
||||
continue
|
||||
depth = 0
|
||||
node = multiscale_tree[tree_ind]
|
||||
#for j, img in enumerate(v):
|
||||
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
||||
while node is not None:
|
||||
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
|
||||
depth += 1
|
||||
node = node.parent
|
Loading…
Reference in New Issue
Block a user