Allow image corruption in multiscale dataset

This commit is contained in:
James Betker 2020-10-19 10:10:27 -06:00
parent 668cafa798
commit ffad0e0422

View File

@ -10,6 +10,9 @@ import torchvision.transforms.functional as F
# Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes. # Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes.
from data.image_corruptor import ImageCorruptor
class MultiScaleDataset(data.Dataset): class MultiScaleDataset(data.Dataset):
def __init__(self, opt): def __init__(self, opt):
super(MultiScaleDataset, self).__init__() super(MultiScaleDataset, self).__init__()
@ -20,6 +23,7 @@ class MultiScaleDataset(data.Dataset):
self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.hq_size_cap = self.tile_size * 2 ** self.num_scales
self.scale = self.opt['scale'] self.scale = self.opt['scale']
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1]) self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
self.corruptor = ImageCorruptor(opt)
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
# offset from center is chosen on a normal probability curve. # offset from center is chosen on a normal probability curve.
@ -62,13 +66,19 @@ class MultiScaleDataset(data.Dataset):
img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA) img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
self.recursively_extract_patches(img_full, patches_hq, 1) self.recursively_extract_patches(img_full, patches_hq, 1)
# Image corruption is applied against the full size image for this dataset.
img_corrupted = self.corruptor.corrupt_images([img_full])[0]
patches_hq_corrupted = [cv2.resize(img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
self.recursively_extract_patches(img_corrupted, patches_hq_corrupted, 1)
# BGR to RGB, HWC to CHW, numpy to tensor # BGR to RGB, HWC to CHW, numpy to tensor
if patches_hq[0].shape[2] == 3: if patches_hq[0].shape[2] == 3:
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq] patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
patches_hq_corrupted = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted]
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq] patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
patches_hq = torch.stack(patches_hq, dim=0) patches_hq = torch.stack(patches_hq, dim=0)
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq] patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='bilinear').squeeze() for p in patches_hq_corrupted]
patches_lq = torch.stack(patches_lq, dim=0) patches_lq = torch.stack(patches_lq, dim=0)
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path} d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
@ -118,10 +128,14 @@ def _build_multiscale_patch_index_map(depth, ind, node, leaves):
if __name__ == '__main__': if __name__ == '__main__':
opt = { opt = {
'name': 'amalgam', 'name': 'amalgam',
'dataroot': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half'], 'paths': ['F:\\4k6k\\datasets\\images\\div2k\\DIV2K_train_HR'],
'num_scales': 4, 'num_scales': 4,
'scale': 2, 'scale': 2,
'hq_tile_size': 128 'hq_tile_size': 128,
'fixed_corruptions': ['jpeg'],
'random_corruptions': ['gaussian_blur', 'motion-blur', 'noise-5'],
'num_corrupts_per_image': 1,
'corruption_blur_scale': 5
} }
import torchvision import torchvision
@ -133,14 +147,15 @@ if __name__ == '__main__':
quadrant=2 quadrant=2
print(i) print(i)
o = ds[random.randint(0, len(ds))] o = ds[random.randint(0, len(ds))]
k = 'HQ'
v = o['HQ']
#for j, img in enumerate(v):
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
tree_ind = random.randint(0, len(multiscale_tree)) tree_ind = random.randint(0, len(multiscale_tree))
node = multiscale_tree[tree_ind] for k, v in o.items():
depth = 0 if 'path' in k:
while node is not None: continue
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth)) depth = 0
depth += 1 node = multiscale_tree[tree_ind]
node = node.parent #for j, img in enumerate(v):
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
while node is not None:
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
depth += 1
node = node.parent