From 1ba01d69b5b87bbdc28e627161ad7e1dbf12ad17 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 15 Oct 2020 17:18:23 -0600 Subject: [PATCH] Move datasets to INTER_AREA interpolation for downsizing Looks **FAR** better visually --- codes/data/base_unsupervised_image_dataset.py | 12 ++-- codes/data/multiscale_dataset.py | 59 +++++++++++-------- 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/codes/data/base_unsupervised_image_dataset.py b/codes/data/base_unsupervised_image_dataset.py index 7b7fe10a..a00209f6 100644 --- a/codes/data/base_unsupervised_image_dataset.py +++ b/codes/data/base_unsupervised_image_dataset.py @@ -68,9 +68,9 @@ class BaseUnsupervisedImageDataset(data.Dataset): for hq, hq_ref, hq_mask, hq_center in zip(imgs_hq, refs_hq, masks_hq, centers_hq): # It is assumed that the target size is a square. target_size = (self.target_hq_size, self.target_hq_size) - hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_LINEAR)) - hq_refs_adjusted.append(cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_LINEAR)) - hq_masks_adjusted.append(cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_LINEAR)) + hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_AREA)) + hq_refs_adjusted.append(cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_AREA)) + hq_masks_adjusted.append(cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_AREA)) hq_centers_adjusted.append(self.resize_point(hq_center, (h, w), target_size)) h, w = self.target_hq_size, self.target_hq_size else: @@ -97,9 +97,9 @@ class BaseUnsupervisedImageDataset(data.Dataset): lms.append(hq_mask) lcs.append(hq_center) else: - ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)) - lrs.append(cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)) - lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR)) + ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) + lrs.append(cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) + lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2])) # Corrupt the LQ image (only in eval mode) if not self.for_eval: diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py index 97c7f167..4a51528c 100644 --- a/codes/data/multiscale_dataset.py +++ b/codes/data/multiscale_dataset.py @@ -40,7 +40,7 @@ class MultiScaleDataset(data.Dataset): return image[:, left:left+h, :] def recursively_extract_patches(self, input_img, result_list, depth): - if depth > self.num_scales: + if depth >= self.num_scales: return patch_size = self.hq_size_cap // (2 ** depth) # First pull the four sub-patches. @@ -48,7 +48,7 @@ class MultiScaleDataset(data.Dataset): input_img[:patch_size, patch_size:], input_img[patch_size:, :patch_size], input_img[patch_size:, patch_size:]] - result_list.extend([cv2.resize(p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_LINEAR) for p in patches]) + result_list.extend([cv2.resize(p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches]) for p in patches: self.recursively_extract_patches(p, result_list, depth+1) @@ -59,8 +59,8 @@ class MultiScaleDataset(data.Dataset): img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0] img_full = util.augment([img_full], True, True)[0] img_full = self.get_square_image(img_full) - img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_LINEAR) - patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_LINEAR)] + img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA) + patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] self.recursively_extract_patches(img_full, patches_hq, 1) # BGR to RGB, HWC to CHW, numpy to tensor @@ -75,24 +75,34 @@ class MultiScaleDataset(data.Dataset): def __len__(self): return len(self.paths_hq) +class MultiscaleTreeNode: + def __init__(self, index, parent): + self.index = index + self.parent = parent + self.children = [] + + def add_child(self, child): + self.children.append(child) + return child + def build_multiscale_patch_index_map(depth): if depth < 0: return - recursive_list = [] - map = (0, recursive_list) - _build_multiscale_patch_index_map(depth, 1, recursive_list) - return map + root = MultiscaleTreeNode(0, None) + leaves = [] + _build_multiscale_patch_index_map(depth-1, 1, root, leaves) + return leaves -def _build_multiscale_patch_index_map(depth, ind, recursive_list): - if depth <= 0: - return ind - patches = [(ind+i, []) for i in range(4)] - recursive_list.extend(patches) +def _build_multiscale_patch_index_map(depth, ind, node, leaves): + subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node)) for i in range(4)] ind += 4 - for _, p in patches: - ind = _build_multiscale_patch_index_map(depth-1, ind, p) + if depth == 1: + leaves.extend(subnodes) + else: + for n in subnodes: + ind = _build_multiscale_patch_index_map(depth-1, ind, n, leaves) return ind @@ -109,18 +119,19 @@ if __name__ == '__main__': ds = MultiScaleDataset(opt) import os os.makedirs("debug", exist_ok=True) - multiscale_map = build_multiscale_patch_index_map(4) - for i in range(900, len(ds)): + multiscale_tree = build_multiscale_patch_index_map(4) + for i in range(500, len(ds)): quadrant=2 print(i) - o = ds[i] + o = ds[random.randint(0, len(ds))] k = 'HQ' v = o['HQ'] #for j, img in enumerate(v): # torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j)) - torchvision.utils.save_image(v[0].unsqueeze(0), "debug/%i_%s_0.png" % (i, k)) - map_tuple = multiscale_map[1][quadrant] - while map_tuple[1]: - ind = map_tuple[0] - torchvision.utils.save_image(v[ind].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, ind+1)) - map_tuple = map_tuple[1][quadrant] \ No newline at end of file + tree_ind = random.randint(0, len(multiscale_tree)) + node = multiscale_tree[tree_ind] + depth = 0 + while node is not None: + torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth)) + depth += 1 + node = node.parent \ No newline at end of file