"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets""" import sys import os.path as osp import glob import pickle from multiprocessing import Pool import numpy as np import lmdb import cv2 sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) import data.util as data_util # noqa: E402 import utils.util as util # noqa: E402 def main(): dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test mode = 'GT' # used for vimeo90k and REDS datasets # vimeo90k: GT | LR | flow # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp # train_sharp_flowx4 if dataset == 'vimeo90k': vimeo90k(mode) elif dataset == 'REDS': REDS(mode) elif dataset == 'general': opt = {} opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub' opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' opt['name'] = 'DIV2K800_sub_GT' general_image_folder(opt) elif dataset == 'DIV2K_demo': opt = {} ## GT opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub' opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' opt['name'] = 'DIV2K800_sub_GT' general_image_folder(opt) ## LR opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' opt['name'] = 'DIV2K800_sub_bicLRx4' general_image_folder(opt) elif dataset == 'test': test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS') def read_image_worker(path, key): img = cv2.imread(path, cv2.IMREAD_UNCHANGED) return (key, img) def general_image_folder(opt): """Create lmdb for general image folders Users should define the keys, such as: '0321_s035' for DIV2K sub-images If all the images have the same resolution, it will only store one copy of resolution info. Otherwise, it will store every resolution info. """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False n_thread = 40 ######################################################## img_folder = opt['img_folder'] lmdb_save_path = opt['lmdb_save_path'] meta_info = {'name': opt['name']} if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = sorted(glob.glob(osp.join(img_folder, '*'))) keys = [] for img_path in all_img_list: keys.append(osp.splitext(osp.basename(img_path))[0]) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = {} # store all image data. 


def general_image_folder(opt):
    """Create lmdb for general image folders.

    Users should define the keys, such as '0321_s035' for DIV2K sub-images.
    If all the images have the same resolution, only one copy of the
    resolution info is stored; otherwise, the resolution of every image is
    stored.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = 5000  # commit lmdb every BATCH images if read_all_imgs is False
    n_thread = 40
    ########################################################
    img_folder = opt['img_folder']
    lmdb_save_path = opt['lmdb_save_path']
    meta_info = {'name': opt['name']}
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
    keys = []
    for img_path in all_img_list:
        keys.append(osp.splitext(osp.basename(img_path))[0])

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data; a list cannot keep the order, so use a dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    resolutions = []
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if data.ndim == 2:
            H, W = data.shape
            C = 1
        else:
            H, W, C = data.shape
        txn.put(key_byte, data)
        resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    # check whether all the images have the same resolution
    assert len(keys) == len(resolutions)
    if len(set(resolutions)) <= 1:
        meta_info['resolution'] = [resolutions[0]]
        meta_info['keys'] = keys
        print('All images have the same resolution. Simplify the meta info.')
    else:
        meta_info['resolution'] = resolutions
        meta_info['keys'] = keys
        print('Not all images have the same resolution. Save meta info for each image.')
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
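

# Usage sketch for general_image_folder() with a hypothetical dataset (the
# 'MyData' paths and name are placeholders, not part of this repo):
#     opt = {'img_folder': '../../datasets/MyData/HR_sub',
#            'lmdb_save_path': '../../datasets/MyData/HR_sub.lmdb',
#            'name': 'MyData_HR_sub'}
#     general_image_folder(opt)
# Note: map_size is estimated as 10x (first image's nbytes x image count),
# which assumes roughly uniform image sizes; increase the factor for folders
# with widely varying resolutions.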


def vimeo90k(mode):
    """Create lmdb for the Vimeo90K dataset, each image with a fixed size.

    GT: [3, 256, 448]
        Now only the 4th frame is needed, e.g., 00001_0001_4
    LR: [3, 64, 112]
        1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
    key: use the folder and subfolder names, w/o the frame index,
        e.g., 00001_0001
    flow: downsampled flow: [128, 112] (single channel),
        keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
        Each flow is calculated with GT images by PWCNet and then downsampled
        by 1/4. The flow map is quantized by mmcv and saved in png format.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = 5000  # commit lmdb every BATCH images if read_all_imgs is False
    if mode == 'GT':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 256, 448
    elif mode == 'LR':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 64, 112
    elif mode == 'flow':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 128, 112
    else:
        raise ValueError('Wrong dataset mode: {}'.format(mode))
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(txt_file) as f:
        train_l = f.readlines()
        train_l = [v.strip() for v in train_l]
    all_img_list = []
    keys = []
    for line in train_l:
        folder = line.split('/')[0]
        sub_folder = line.split('/')[1]
        all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
        if mode == 'flow':
            for j in range(1, 4):
                keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
                keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
        else:
            for j in range(7):
                keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
    if mode == 'GT':  # only read the 4th frame for the GT mode
        print('Only keep the 4th frame.')
        all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
        keys = [v for v in keys if v.endswith('_4')]

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data; a list cannot keep the order, so use a dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### write data to lmdb
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    txn = env.begin(write=True)
    pbar = util.ProgressBar(len(all_img_list))
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    if mode == 'GT':
        meta_info['name'] = 'Vimeo90K_train_GT'
    elif mode == 'LR':
        meta_info['name'] = 'Vimeo90K_train_LR'
    elif mode == 'flow':
        meta_info['name'] = 'Vimeo90K_train_flowx4'
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    key_set = set()
    for key in keys:
        if mode == 'flow':
            a, b, _, _ = key.split('_')
        else:
            a, b, _ = key.split('_')
        key_set.add('{}_{}'.format(a, b))
    meta_info['keys'] = list(key_set)
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
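

# Key layout written by vimeo90k() above: meta_info['keys'] holds folder-level
# keys such as '00001_0001', while the lmdb entries themselves are per frame
# ('00001_0001_1' ... '00001_0001_7' for LR, only '00001_0001_4' for GT). A
# loader therefore recombines them, e.g. (sketch, not repo API):
#     frame_keys = ['{}_{}'.format(key, i) for i in range(1, 8)]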


def REDS(mode):
    """Create lmdb for the REDS dataset, each image with a fixed size.

    GT: [3, 720, 1280], key: 000_00000000
    LR: [3, 180, 320], key: 000_00000000
    flow: downsampled flow: [360, 320] (single channel),
        keys: 000_00000005_[p2, p1, n1, n2]
        Each flow is calculated with the GT images by PWCNet and then
        downsampled by 1/4. The flow map is quantized by mmcv and saved in
        png format.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = 5000  # commit lmdb every BATCH images if read_all_imgs is False
    if mode == 'train_sharp':
        img_folder = '../../datasets/REDS/train_sharp'
        lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_bicubic':
        img_folder = '../../datasets/REDS/train_sharp_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur_bicubic':
        img_folder = '../../datasets/REDS/train_blur_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur':
        img_folder = '../../datasets/REDS/train_blur'
        lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_blur_comp':
        img_folder = '../../datasets/REDS/train_blur_comp'
        lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_flowx4':
        img_folder = '../../datasets/REDS/train_sharp_flowx4'
        lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
        H_dst, W_dst = 360, 320
    else:
        raise ValueError('Wrong dataset mode: {}'.format(mode))
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data; a list cannot keep the order, so use a dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
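

# Quick integrity check sketch: compare the number of lmdb records against the
# meta info (uses only py-lmdb calls already imported above; the function name
# is illustrative, not repo API). For REDS and the general folders the count
# should equal len(meta_info['keys']); for Vimeo90K, keys are folder-level, so
# multiply by the number of frames stored per key.
def count_lmdb_entries_example(dataroot):
    """Return the number of records in an lmdb created by the writers above."""
    env = lmdb.open(dataroot, readonly=True, lock=False)
    n_entries = env.stat()['entries']
    env.close()
    return n_entries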


def test_lmdb(dataroot, dataset='REDS'):
    env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
    meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
    print('Name: ', meta_info['name'])
    print('Resolution: ', meta_info['resolution'])
    print('# keys: ', len(meta_info['keys']))
    # read one image
    if dataset == 'vimeo90k':
        key = '00001_0001_4'
    else:
        key = '000_00000000'
    print('Reading {} for test.'.format(key))
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
    img = img_flat.reshape(H, W, C)
    cv2.imwrite('test.png', img)


if __name__ == "__main__":
    main()