Allow image corruption in multiscale dataset
This commit is contained in:
parent
668cafa798
commit
ffad0e0422
|
@ -10,6 +10,9 @@ import torchvision.transforms.functional as F
|
||||||
|
|
||||||
|
|
||||||
# Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes.
|
# Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes.
|
||||||
|
from data.image_corruptor import ImageCorruptor
|
||||||
|
|
||||||
|
|
||||||
class MultiScaleDataset(data.Dataset):
|
class MultiScaleDataset(data.Dataset):
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(MultiScaleDataset, self).__init__()
|
super(MultiScaleDataset, self).__init__()
|
||||||
|
@ -20,6 +23,7 @@ class MultiScaleDataset(data.Dataset):
|
||||||
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
|
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
|
||||||
self.scale = self.opt['scale']
|
self.scale = self.opt['scale']
|
||||||
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
|
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
|
||||||
|
self.corruptor = ImageCorruptor(opt)
|
||||||
|
|
||||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||||
# offset from center is chosen on a normal probability curve.
|
# offset from center is chosen on a normal probability curve.
|
||||||
|
@ -62,13 +66,19 @@ class MultiScaleDataset(data.Dataset):
|
||||||
img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
|
img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
|
||||||
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||||
self.recursively_extract_patches(img_full, patches_hq, 1)
|
self.recursively_extract_patches(img_full, patches_hq, 1)
|
||||||
|
# Image corruption is applied against the full size image for this dataset.
|
||||||
|
img_corrupted = self.corruptor.corrupt_images([img_full])[0]
|
||||||
|
patches_hq_corrupted = [cv2.resize(img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||||
|
self.recursively_extract_patches(img_corrupted, patches_hq_corrupted, 1)
|
||||||
|
|
||||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||||
if patches_hq[0].shape[2] == 3:
|
if patches_hq[0].shape[2] == 3:
|
||||||
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
|
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
|
||||||
|
patches_hq_corrupted = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted]
|
||||||
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
|
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
|
||||||
patches_hq = torch.stack(patches_hq, dim=0)
|
patches_hq = torch.stack(patches_hq, dim=0)
|
||||||
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq]
|
patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
|
||||||
|
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq_corrupted]
|
||||||
patches_lq = torch.stack(patches_lq, dim=0)
|
patches_lq = torch.stack(patches_lq, dim=0)
|
||||||
|
|
||||||
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
|
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
|
||||||
|
@ -118,10 +128,14 @@ def _build_multiscale_patch_index_map(depth, ind, node, leaves):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
opt = {
|
opt = {
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'],
|
'paths': ['F:\\4k6k\\datasets\\images\\div2k\\DIV2K_train_HR'],
|
||||||
'num_scales': 4,
|
'num_scales': 4,
|
||||||
'scale': 2,
|
'scale': 2,
|
||||||
'hq_tile_size': 128
|
'hq_tile_size': 128,
|
||||||
|
'fixed_corruptions': ['jpeg'],
|
||||||
|
'random_corruptions': ['gaussian_blur', 'motion-blur', 'noise-5'],
|
||||||
|
'num_corrupts_per_image': 1,
|
||||||
|
'corruption_blur_scale': 5
|
||||||
}
|
}
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
|
@ -133,13 +147,14 @@ if __name__ == '__main__':
|
||||||
quadrant=2
|
quadrant=2
|
||||||
print(i)
|
print(i)
|
||||||
o = ds[random.randint(0, len(ds))]
|
o = ds[random.randint(0, len(ds))]
|
||||||
k = 'HQ'
|
tree_ind = random.randint(0, len(multiscale_tree))
|
||||||
v = o['HQ']
|
for k, v in o.items():
|
||||||
|
if 'path' in k:
|
||||||
|
continue
|
||||||
|
depth = 0
|
||||||
|
node = multiscale_tree[tree_ind]
|
||||||
#for j, img in enumerate(v):
|
#for j, img in enumerate(v):
|
||||||
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
||||||
tree_ind = random.randint(0, len(multiscale_tree))
|
|
||||||
node = multiscale_tree[tree_ind]
|
|
||||||
depth = 0
|
|
||||||
while node is not None:
|
while node is not None:
|
||||||
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
|
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
|
||||||
depth += 1
|
depth += 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user