Extract subimages mod

This commit is contained in:
James Betker 2020-05-23 21:07:41 -06:00
parent 90073fc761
commit 445e7e7053

View File

@ -13,18 +13,19 @@ import data.util as data_util # noqa: E402
def main(): def main():
mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs) mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs)
split_img = False
opt = {} opt = {}
opt['n_thread'] = 20 opt['n_thread'] = 20
opt['compression_level'] = 3 # 3 is the default value in cv2 opt['compression_level'] = 3 # 3 is the default value in cv2
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed. # compression time. If read raw images during training, use 0 for faster IO speed.
if mode == 'single': if mode == 'single':
opt['input_folder'] = 'D:\\vix_cropped' opt['input_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\images'
opt['save_folder'] = 'D:\\vix_tiled' opt['save_folder'] = 'Z:\\4k6k\\datasets\\adrianna\\adrianna_vids\\tiled'
opt['crop_sz'] = 800 # the size of each sub-image opt['crop_sz'] = 64 # the size of each sub-image
opt['step'] = 640 # step of the sliding crop window opt['step'] = 48 # step of the sliding crop window
opt['thres_sz'] = 200 # size threshold opt['thres_sz'] = 20 # size threshold
extract_signle(opt) extract_single(opt, split_img)
elif mode == 'pair': elif mode == 'pair':
GT_folder = '../../datasets/div2k/DIV2K_train_HR' GT_folder = '../../datasets/div2k/DIV2K_train_HR'
LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4' LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4'
@ -60,14 +61,14 @@ def main():
opt['crop_sz'] = crop_sz opt['crop_sz'] = crop_sz
opt['step'] = step opt['step'] = step
opt['thres_sz'] = thres_sz opt['thres_sz'] = thres_sz
extract_signle(opt) extract_single(opt)
print('process LR...') print('process LR...')
opt['input_folder'] = LR_folder opt['input_folder'] = LR_folder
opt['save_folder'] = save_LR_folder opt['save_folder'] = save_LR_folder
opt['crop_sz'] = crop_sz // scale_ratio opt['crop_sz'] = crop_sz // scale_ratio
opt['step'] = step // scale_ratio opt['step'] = step // scale_ratio
opt['thres_sz'] = thres_sz // scale_ratio opt['thres_sz'] = thres_sz // scale_ratio
extract_signle(opt) extract_single(opt)
assert len(data_util._get_paths_from_images(save_GT_folder)) == len( assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
data_util._get_paths_from_images( data_util._get_paths_from_images(
save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.' save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
@ -75,7 +76,7 @@ def main():
raise ValueError('Wrong mode.') raise ValueError('Wrong mode.')
def extract_signle(opt): def extract_single(opt, split_img=False):
input_folder = opt['input_folder'] input_folder = opt['input_folder']
save_folder = opt['save_folder'] save_folder = opt['save_folder']
if not osp.exists(save_folder): if not osp.exists(save_folder):
@ -93,13 +94,17 @@ def extract_signle(opt):
pool = Pool(opt['n_thread']) pool = Pool(opt['n_thread'])
for path in img_list: for path in img_list:
pool.apply_async(worker, args=(path, opt), callback=update) if split_img:
pool.apply_async(worker, args=(path, opt, True, False), callback=update)
pool.apply_async(worker, args=(path, opt, True, True), callback=update)
else:
pool.apply_async(worker, args=(path, opt), callback=update)
pool.close() pool.close()
pool.join() pool.join()
print('All subprocesses done.') print('All subprocesses done.')
def worker(path, opt): def worker(path, opt, split_mode=False, left_img=True):
crop_sz = opt['crop_sz'] crop_sz = opt['crop_sz']
step = opt['step'] step = opt['step']
thres_sz = opt['thres_sz'] thres_sz = opt['thres_sz']
@ -113,6 +118,22 @@ def worker(path, opt):
h, w, c = img.shape h, w, c = img.shape
else: else:
raise ValueError('Wrong image shape - {}'.format(n_channels)) raise ValueError('Wrong image shape - {}'.format(n_channels))
# Uncomment to filter any image that doesnt meet a threshold size.
#if w < 3000:
# return
left = 0
right = w
if split_mode:
if left_img:
left = 0
right = int(w/2)
else:
left = int(w/2)
right = w
w = int(w/2)
img = img[:, left:right]
h_space = np.arange(0, h - crop_sz + 1, step) h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz: if h - (h_space[-1] + crop_sz) > thres_sz:
@ -130,9 +151,11 @@ def worker(path, opt):
else: else:
crop_img = img[x:x + crop_sz, y:y + crop_sz, :] crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
crop_img = np.ascontiguousarray(crop_img) crop_img = np.ascontiguousarray(crop_img)
# If this fails, change it and the imwrite below to the write extension.
assert img_name.contains(".png")
cv2.imwrite( cv2.imwrite(
osp.join(opt['save_folder'], osp.join(opt['save_folder'],
img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img, img_name.replace('.png', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
return 'Processing {:s} ...'.format(img_name) return 'Processing {:s} ...'.format(img_name)