From 57fc3f490c05809ffdc0360e923e6704250abf52 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 17 Sep 2020 13:30:51 -0600 Subject: [PATCH] Add script for extracting image tiles with reference images --- ..._lmdb.py => extract_subimages_with_ref.py} | 177 ++++++++++++------ 1 file changed, 117 insertions(+), 60 deletions(-) rename codes/data_scripts/{extract_subimages_with_ref_lmdb.py => extract_subimages_with_ref.py} (68%) diff --git a/codes/data_scripts/extract_subimages_with_ref_lmdb.py b/codes/data_scripts/extract_subimages_with_ref.py similarity index 68% rename from codes/data_scripts/extract_subimages_with_ref_lmdb.py rename to codes/data_scripts/extract_subimages_with_ref.py index 0a99ac4b..50c3c4b8 100644 --- a/codes/data_scripts/extract_subimages_with_ref_lmdb.py +++ b/codes/data_scripts/extract_subimages_with_ref.py @@ -9,6 +9,7 @@ import lmdb import pyarrow import torch.utils.data as data from tqdm import tqdm +import torch def main(): @@ -19,15 +20,28 @@ def main(): opt['compression_level'] = 90 # JPEG compression quality rating. # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. + if mode == 'single': - opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images' - opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\lmdb_with_ref' - opt['crop_sz'] = 512 # the size of each sub-image - opt['step'] = 128 # step of the sliding crop window + opt['dest'] = 'file' + opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_segments' + opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_with_refs' + opt['crop_sz'] = [256, 512, 1024] # the size of each sub-image + opt['step'] = 256 # step of the sliding crop window opt['thres_sz'] = 128 # size threshold - opt['resize_final_img'] = .5 + opt['resize_final_img'] = [1, .5, .25] opt['only_resize'] = False - extract_single(opt, split_img) + + save_folder = opt['save_folder'] + if not osp.exists(save_folder): + os.makedirs(save_folder) + print('mkdir [{:s}] ...'.format(save_folder)) + + if opt['dest'] == 'lmdb': + writer = LmdbWriter(save_folder) + else: + writer = FileWriter(save_folder) + + extract_single(opt, writer, split_img) elif mode == 'pair': GT_folder = '../../datasets/div2k/DIV2K_train_HR' LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4' @@ -91,7 +105,7 @@ class LmdbWriter: self.keys = [] # Writes the given reference image to the db and returns its ID. - def write_reference_image(self, ref_img): + def write_reference_image(self, ref_img, _): id = self.ref_id self.ref_id += 1 self.write_image(id, ref_img[0], ref_img[1]) @@ -123,6 +137,48 @@ class LmdbWriter: self.db.close() +class FileWriter: + def __init__(self, folder): + self.folder = folder + self.next_unique_id = 0 + self.ref_center_points = {} # Maps ref_img basename to a dict of image IDs:center points + self.ref_ids_to_names = {} + + def get_next_unique_id(self): + id = self.next_unique_id + self.next_unique_id += 1 + return id + + def save_image(self, ref_path, img_name, img): + save_path = osp.join(self.folder, ref_path) + os.makedirs(save_path, exist_ok=True) + f = open(osp.join(save_path, img_name), "wb") + f.write(img) + f.close() + + # Writes the given reference image to the db and returns its ID. + def write_reference_image(self, ref_img, path): + ref_img, _ = ref_img # Encoded with a center point, which is irrelevant for the reference image. + img_name = osp.basename(path).replace(".jpg", "").replace(".png", "") + self.ref_center_points[img_name] = {} + self.save_image(img_name, "ref.jpg", ref_img) + id = self.get_next_unique_id() + self.ref_ids_to_names[id] = img_name + return id + + # Writes a tile image to the db given a reference image and returns its ID. + def write_tile_image(self, ref_id, tile_image): + id = self.get_next_unique_id() + ref_name = self.ref_ids_to_names[ref_id] + img, center = tile_image + self.ref_center_points[ref_name][id] = center + self.save_image(ref_name, "%08i.jpg" % (id,), img) + return id + + def close(self): + for ref_name, cps in self.ref_center_points.items(): + torch.save(cps, osp.join(self.folder, ref_name, "centers.pt")) + class TiledDataset(data.Dataset): def __init__(self, opt, split_mode=False): self.split_mode = split_mode @@ -136,16 +192,48 @@ class TiledDataset(data.Dataset): else: return self.get(index, False, False) - def get(self, index, split_mode, left_img): - path = self.images[index] - crop_sz = self.opt['crop_sz'] + def get_for_scale(self, img, split_mode, left_img, crop_sz, resize_factor): step = self.opt['step'] thres_sz = self.opt['thres_sz'] - only_resize = self.opt['only_resize'] + + h, w, c = img.shape + if split_mode: + w = w/2 + + h_space = np.arange(0, h - crop_sz + 1, step) + if h - (h_space[-1] + crop_sz) > thres_sz: + h_space = np.append(h_space, h - crop_sz) + w_space = np.arange(0, w - crop_sz + 1, step) + if w - (w_space[-1] + crop_sz) > thres_sz: + w_space = np.append(w_space, w - crop_sz) + + index = 0 + tile_dim = int(crop_sz * resize_factor) + dsize = (tile_dim, tile_dim) + results = [] + for x in h_space: + for y in w_space: + index += 1 + crop_img = img[x:x + crop_sz, y:y + crop_sz, :] + center_point = (x + crop_sz // 2, y + crop_sz // 2) + crop_img = np.ascontiguousarray(crop_img) + if 'resize_final_img' in self.opt.keys(): + # Resize too. + center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor)) + crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA) + success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) + assert success + results.append((buffer, center_point)) + return results + + def get(self, index, split_mode, left_img): + path = self.images[index] img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # We must convert the image into a square. Crop the image so that only the center is left, since this is often # the most salient part of the image. + if len(img.shape) == 2: # Greyscale not supported. + return None h, w, c = img.shape dim = min(h, w) img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :] @@ -153,7 +241,7 @@ class TiledDataset(data.Dataset): h, w, c = img.shape # Uncomment to filter any image that doesnt meet a threshold size. if min(h,w) < 1024: - return + return None left = 0 right = w if split_mode: @@ -163,48 +251,20 @@ class TiledDataset(data.Dataset): else: left = int(w/2) right = w - w = int(w/2) img = img[:, left:right] - h_space = np.arange(0, h - crop_sz + 1, step) - if h - (h_space[-1] + crop_sz) > thres_sz: - h_space = np.append(h_space, h - crop_sz) - w_space = np.arange(0, w - crop_sz + 1, step) - if w - (w_space[-1] + crop_sz) > thres_sz: - w_space = np.append(w_space, w - crop_sz) + tile_dim = int(self.opt['crop_sz'][0] * self.opt['resize_final_img'][0]) + dsize = (tile_dim, tile_dim) - dsize = None - if only_resize: - dsize = (crop_sz, crop_sz) - if h < w: - h_space = [0] - w_space = [(w - h) // 2] - crop_sz = h - else: - h_space = [(h - w) // 2] - w_space = [0] - crop_sz = w + # Reference image should always be first entry in results. + ref_img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA) + success, ref_buffer = cv2.imencode(".jpg", ref_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) + assert success + results = [(ref_buffer, (-1,-1))] - index = 0 - resize_factor = self.opt['resize_final_img'] if 'resize_final_img' in self.opt.keys() else 1 - dsize = (int(crop_sz * resize_factor), int(crop_sz * resize_factor)) - # Reference image should always be first. - results = [(cv2.resize(img, dsize, interpolation=cv2.INTER_AREA), (-1,-1))] - for x in h_space: - for y in w_space: - index += 1 - crop_img = img[x:x + crop_sz, y:y + crop_sz, :] - center_point = (x + crop_sz // 2, y + crop_sz // 2) - crop_img = np.ascontiguousarray(crop_img) - if 'resize_final_img' in self.opt.keys(): - # Resize too. - resize_factor = self.opt['resize_final_img'] - center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor)) - crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA) - success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) - assert success - results.append((buffer, center_point)) - return results + for crop_sz, resize_factor in zip(self.opt['crop_sz'], self.opt['resize_final_img']): + results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, resize_factor)) + return results, path def __len__(self): return len(self.images) @@ -213,23 +273,20 @@ class TiledDataset(data.Dataset): def identity(x): return x -def extract_single(opt, split_img=False): - save_folder = opt['save_folder'] - if not osp.exists(save_folder): - os.makedirs(save_folder) - print('mkdir [{:s}] ...'.format(save_folder)) - lmdb = LmdbWriter(save_folder) - +def extract_single(opt, writer, split_img=False): dataset = TiledDataset(opt, split_img) dataloader = data.DataLoader(dataset, num_workers=opt['n_thread'], collate_fn=identity) tq = tqdm(dataloader) for imgs in tq: + if imgs is None or imgs[0] is None: + continue + imgs, path = imgs[0] if imgs is None or len(imgs) <= 1: continue - ref_id = lmdb.write_reference_image(imgs[0]) + ref_id = writer.write_reference_image(imgs[0], path) for tile in imgs[1:]: - lmdb.write_tile_image(ref_id, tile) - lmdb.close() + writer.write_tile_image(ref_id, tile) + writer.close() if __name__ == '__main__':