forked from mrq/DL-Art-School
11155aead4
This has been a long time coming. Cleans up messy "GT" nomenclature and simplifies ExtensibleTrainer.feed_data
412 lines
16 KiB
Python
412 lines
16 KiB
Python
"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""
|
|
|
|
import sys
|
|
import os.path as osp
|
|
import glob
|
|
import pickle
|
|
from multiprocessing import Pool
|
|
import numpy as np
|
|
import lmdb
|
|
import cv2
|
|
|
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
|
import data.util as data_util # noqa: E402
|
|
import utils.util as util # noqa: E402
|
|
|
|
|
|
def main(dataset='DIV2K_demo', mode='hq'):
    """Run one lmdb conversion selected by ``dataset``.

    Args:
        dataset: one of 'vimeo90k' | 'REDS' | 'general' (e.g., DIV2K, 291)
            | 'DIV2K_demo' | 'test'. Matching is case-sensitive.
        mode: only used for the 'vimeo90k' and 'REDS' datasets.
            vimeo90k: 'hq' | 'LR' | 'flow'
            REDS: 'train_sharp' | 'train_sharp_bicubic' | 'train_blur_bicubic'
                | 'train_blur' | 'train_blur_comp' | 'train_sharp_flowx4'

    Raises:
        ValueError: if ``dataset`` is not one of the names listed above.
            Without this, a misspelled name silently did nothing.
    """
    if dataset == 'vimeo90k':
        vimeo90k(mode)
    elif dataset == 'REDS':
        REDS(mode)
    elif dataset == 'general':
        opt = {}
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['name'] = 'DIV2K800_sub_GT'
        general_image_folder(opt)
    elif dataset == 'DIV2K_demo':
        opt = {}
        ## GT
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['name'] = 'DIV2K800_sub_GT'
        general_image_folder(opt)
        ## LR
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
        opt['name'] = 'DIV2K800_sub_bicLRx4'
        general_image_folder(opt)
    elif dataset == 'test':
        test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS')
    else:
        raise ValueError('Unknown dataset: {}'.format(dataset))
|
|
|
|
|
|
def read_image_worker(path, key):
    """Pool worker: decode the image at ``path`` and tag it with ``key``.

    The file is read with ``cv2.IMREAD_UNCHANGED``. The key is returned
    alongside the pixels so the ``apply_async`` callback can file the result
    into a dict, regardless of the order in which workers finish.

    Returns:
        ``(key, image)`` where ``image`` is whatever ``cv2.imread`` returned.
    """
    pixels = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    result = (key, pixels)
    return result
|
|
|
|
|
|
def general_image_folder(opt):
    """Create an lmdb database from a folder of images.

    Every entry matched by ``<img_folder>/*`` is read with
    ``cv2.IMREAD_UNCHANGED`` and stored under its file name without the
    extension. Users should therefore name the files with the desired keys,
    such as '0321_s035' for DIV2K sub-images. Keys must be ASCII (they are
    encoded with 'ascii') and unique.

    A ``meta_info.pkl`` is written into the lmdb folder. It holds the dataset
    name, the keys and the 'C_H_W' resolution(s). If all the images have the
    same resolution, only one copy of the resolution info is stored;
    otherwise one resolution per image is stored, in key order.

    Args:
        opt (dict): with the entries
            'img_folder': folder containing the images,
            'lmdb_save_path': output path; must end with '.lmdb' and must not
                exist yet,
            'name': dataset name stored in the meta info.

    Raises:
        ValueError: if ``lmdb_save_path`` does not end with '.lmdb', if the
            folder contains no entries, or if an entry cannot be decoded by
            ``cv2.imread``.

    The process exits with status 1 if ``lmdb_save_path`` already exists.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images to memory with multiprocessing
    # Set False for use limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    n_thread = 40
    ########################################################
    img_folder = opt['img_folder']
    lmdb_save_path = opt['lmdb_save_path']
    meta_info = {'name': opt['name']}
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
    if not all_img_list:
        # Fail early with a clear message instead of an IndexError further down.
        raise ValueError('No images found in [{:s}].'.format(img_folder))
    # The key of each image is its file name without the extension.
    keys = [osp.splitext(osp.basename(img_path))[0] for img_path in all_img_list]

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    first_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED)
    if first_img is None:
        raise ValueError('Cannot read image [{:s}].'.format(all_img_list[0]))
    data_size_per_img = first_img.nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    # The map size is estimated from the first image only; the 10x factor is
    # presumably headroom for larger images later in the folder.
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    resolutions = []
    try:
        txn = env.begin(write=True)
        for idx, (path, key) in enumerate(zip(all_img_list, keys)):
            pbar.update('Write {}'.format(key))
            key_byte = key.encode('ascii')
            data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
            if data is None:
                # e.g. a non-image file picked up by the '*' glob
                raise ValueError('Cannot read image [{:s}].'.format(path))
            if data.ndim == 2:
                H, W = data.shape
                C = 1
            else:
                H, W, C = data.shape
            txn.put(key_byte, data)
            resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
            if not read_all_imgs and idx % BATCH == 0:
                txn.commit()
                txn = env.begin(write=True)
        txn.commit()
    finally:
        # Always release the environment, even if writing fails part-way.
        env.close()
    print('Finish writing lmdb.')

    #### create meta information
    # check whether all the images are the same size
    assert len(keys) == len(resolutions)
    if len(set(resolutions)) <= 1:
        meta_info['resolution'] = [resolutions[0]]
        meta_info['keys'] = keys
        print('All images have the same resolution. Simplify the meta info.')
    else:
        meta_info['resolution'] = resolutions
        meta_info['keys'] = keys
        print('Not all images have the same resolution. Save meta info for each image.')

    with open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb") as meta_file:
        pickle.dump(meta_info, meta_file)
    print('Finish creating lmdb meta info.')
|
|
|
|
|
|
def vimeo90k(mode):
    """Create lmdb for the Vimeo90K septuplet training set, each image with a fixed size.

    Modes (input/output paths are hard-coded below, relative to the working dir):
        'hq':   [3, 256, 448]; only the 4th frame is needed, e.g. 00001_0001_4
        'LR':   [3, 64, 112]; 1st - 7th frames, e.g. 00001_0001_1, ..., 00001_0001_7
                ('lq' is accepted as an alias of 'LR')
        'flow': single-channel (2-D) [128, 112] images,
                keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3].
                Each flow is calculated with the GT images by PWCNet and then
                downsampled by 1/4; the flow map is quantized by mmcv and saved
                in png format.

    Image paths and keys are sorted independently and paired by position, so
    every clip folder must contain exactly the expected files.

    The meta info stores one key per clip, built from the folder and subfolder
    names w/o the frame index (e.g. 00001_0001), in sorted order.

    Raises:
        ValueError: on an unknown ``mode``, or if the number of images found
            differs from the number of generated keys.

    The process exits with status 1 if the output lmdb already exists.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images to memory with multiprocessing
    # Set False for use limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    if mode == 'hq':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 256, 448
    elif mode in ('LR', 'lq'):
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 64, 112
    elif mode == 'flow':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 128, 112
    else:
        raise ValueError('Wrong dataset mode: {}'.format(mode))
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(txt_file) as f:
        train_l = [v.strip() for v in f.readlines()]
    all_img_list = []
    keys = []
    for line in train_l:
        if not line:
            # tolerate blank lines (e.g. a trailing newline) in the train list
            continue
        folder = line.split('/')[0]
        sub_folder = line.split('/')[1]
        all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
        if mode == 'flow':
            for j in range(1, 4):
                keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
                keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
        else:
            for j in range(7):
                keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
    if mode == 'hq':  # only read the 4th frame for the hq mode
        print('Only keep the 4th frame.')
        all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
        keys = [v for v in keys if v.endswith('_4')]
    if len(all_img_list) != len(keys):
        # Images and keys are paired by position below; a count mismatch would
        # silently store images under the wrong keys.
        raise ValueError('Found {} images but generated {} keys; check the contents of [{:s}].'.format(
            len(all_img_list), len(keys), img_folder))

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### write data to lmdb
    first_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED)
    if first_img is None:
        raise ValueError('Cannot read image [{:s}].'.format(all_img_list[0]))
    data_size_per_img = first_img.nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    try:
        txn = env.begin(write=True)
        pbar = util.ProgressBar(len(all_img_list))
        for idx, (path, key) in enumerate(zip(all_img_list, keys)):
            pbar.update('Write {}'.format(key))
            key_byte = key.encode('ascii')
            data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
            if data is None:
                raise ValueError('Cannot read image [{:s}].'.format(path))
            # every image must match the fixed size declared for this mode
            if 'flow' in mode:
                H, W = data.shape
                assert H == H_dst and W == W_dst, 'different shape.'
            else:
                H, W, C = data.shape
                assert H == H_dst and W == W_dst and C == 3, 'different shape.'
            txn.put(key_byte, data)
            if not read_all_imgs and idx % BATCH == 0:
                txn.commit()
                txn = env.begin(write=True)
        txn.commit()
    finally:
        # Always release the environment, even if writing fails part-way.
        env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    if mode == 'hq':
        meta_info['name'] = 'Vimeo90K_train_GT'
    elif mode in ('LR', 'lq'):
        meta_info['name'] = 'Vimeo90K_train_LR'
    elif mode == 'flow':
        meta_info['name'] = 'Vimeo90K_train_flowx4'
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    # One entry per clip: keep only the 'folder_subfolder' part of each key.
    key_set = {'_'.join(key.split('_')[:2]) for key in keys}
    # sorted so the pickled key order is reproducible across runs
    meta_info['keys'] = sorted(key_set)
    with open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb") as meta_file:
        pickle.dump(meta_info, meta_file)
    print('Finish creating lmdb meta info.')
|
|
|
|
|
|
def REDS(mode):
    """Create lmdb for the REDS dataset, each image with a fixed size.

    Modes (input/output paths are hard-coded below, relative to the working dir):
        'train_sharp', 'train_blur', 'train_blur_comp': [3, 720, 1280]
        'train_sharp_bicubic', 'train_blur_bicubic':    [3, 180, 320]
        'train_sharp_flowx4': single-channel (2-D) [360, 320] images,
            keys: 000_00000005_[p2, p1, n1, n2]. Each flow is calculated with
            the GT images by PWCNet and then downsampled by 1/4; the flow map
            is quantized by mmcv and saved in png format.

    Each key is '<clip folder>_<file name without .png>', e.g. 000_00000000.
    The meta info stores all keys, the name 'REDS_<mode>_wval' and a single
    'C_H_W' resolution string.

    Raises:
        ValueError: on an unknown ``mode``, if no images are found, or if an
            image cannot be decoded.

    The process exits with status 1 if the output lmdb already exists.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images to memory with multiprocessing
    # Set False for use limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    if mode == 'train_sharp':
        img_folder = '../../datasets/REDS/train_sharp'
        lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_bicubic':
        img_folder = '../../datasets/REDS/train_sharp_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur_bicubic':
        img_folder = '../../datasets/REDS/train_blur_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur':
        img_folder = '../../datasets/REDS/train_blur'
        lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_blur_comp':
        img_folder = '../../datasets/REDS/train_blur_comp'
        lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_flowx4':
        img_folder = '../../datasets/REDS/train_sharp_flowx4'
        lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
        H_dst, W_dst = 360, 320
    else:
        # Previously an unknown mode fell through and crashed with an
        # UnboundLocalError on img_folder/lmdb_save_path.
        raise ValueError('Wrong dataset mode: {}'.format(mode))
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    if not all_img_list:
        raise ValueError('No images found in [{:s}].'.format(img_folder))
    keys = []
    for img_path in all_img_list:
        # os.path keeps this portable to platforms whose separator is not '/'
        folder = osp.basename(osp.dirname(img_path))
        img_name = osp.basename(img_path).split('.png')[0]
        keys.append(folder + '_' + img_name)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    first_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED)
    if first_img is None:
        raise ValueError('Cannot read image [{:s}].'.format(all_img_list[0]))
    data_size_per_img = first_img.nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    try:
        txn = env.begin(write=True)
        for idx, (path, key) in enumerate(zip(all_img_list, keys)):
            pbar.update('Write {}'.format(key))
            key_byte = key.encode('ascii')
            data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
            if data is None:
                raise ValueError('Cannot read image [{:s}].'.format(path))
            # every image must match the fixed size declared for this mode
            if 'flow' in mode:
                H, W = data.shape
                assert H == H_dst and W == W_dst, 'different shape.'
            else:
                H, W, C = data.shape
                assert H == H_dst and W == W_dst and C == 3, 'different shape.'
            txn.put(key_byte, data)
            if not read_all_imgs and idx % BATCH == 0:
                txn.commit()
                txn = env.begin(write=True)
        txn.commit()
    finally:
        # Always release the environment, even if writing fails part-way.
        env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    with open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb") as meta_file:
        pickle.dump(meta_info, meta_file)
    print('Finish creating lmdb meta info.')
|
|
|
|
|
|
def test_lmdb(dataroot, dataset='REDS'):
    """Read one image back from an lmdb created by this script and save it as 'test.png'.

    Prints the name, resolution and number of keys from ``meta_info.pkl``.
    It then fetches key '00001_0001_4' when ``dataset == 'vimeo90k'`` and
    '000_00000000' otherwise. The raw bytes are reshaped as uint8 using the
    'C_H_W' resolution string, and the result is written to 'test.png' in the
    working directory.

    ``meta_info['resolution']`` must be a single 'C_H_W' string, as stored by
    the vimeo90k/REDS builders.

    Args:
        dataroot: path of the '.lmdb' folder.
        dataset: 'vimeo90k' selects the Vimeo90K test key; anything else uses
            the REDS key.

    Raises:
        KeyError: if the test key is not present in the database.
    """
    env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
    try:
        # NOTE: pickle must only be loaded from trusted files such as the ones
        # this script writes.
        with open(osp.join(dataroot, 'meta_info.pkl'), "rb") as meta_file:
            meta_info = pickle.load(meta_file)
        print('Name: ', meta_info['name'])
        print('Resolution: ', meta_info['resolution'])
        print('# keys: ', len(meta_info['keys']))
        # read one image
        if dataset == 'vimeo90k':
            key = '00001_0001_4'
        else:
            key = '000_00000000'
        print('Reading {} for test.'.format(key))
        with env.begin(write=False) as txn:
            buf = txn.get(key.encode('ascii'))
    finally:
        env.close()
    if buf is None:
        # Otherwise np.frombuffer(None) fails with an unhelpful TypeError.
        raise KeyError('Key {} not found in [{}].'.format(key, dataroot))
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
    img = img_flat.reshape(H, W, C)
    cv2.imwrite('test.png', img)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()
|