2020-09-04 21:30:34 +00:00
"""Crop large images into a square reference image plus multi-scale sub-image tiles (with their center points) for faster IO, using multiple DataLoader worker processes."""
import os
import os . path as osp
import numpy as np
import cv2
from PIL import Image
import data . util as data_util # noqa: E402
import torch . utils . data as data
from tqdm import tqdm
2020-09-17 19:30:51 +00:00
import torch
2020-09-04 21:30:34 +00:00
def main():
    """Configure and run reference-image + tile extraction.

    All settings live in the ``opt`` dict below; edit them before running.
    The output folder is created if needed. ``extract_single`` then processes
    the images found under ``opt['input_folder']``. Results go to the writer
    selected by ``opt['dest']``: ``'lmdb'`` for an LMDB database, any other
    value for plain files.
    """
    opt = {
        # Number of DataLoader worker processes used to load and crop images.
        'n_thread': 8,
        # JPEG quality (cv2.IMWRITE_JPEG_QUALITY, 0-100) used when encoding the
        # reference image and every tile. Higher = better quality, larger files.
        'compression_level': 90,
        # 'lmdb' writes to an LMDB database; anything else writes plain files.
        'dest': 'file',
        'input_folder': 'F:\\4k6k\\datasets\\ns_images\\fkaw\\images',
        'save_folder': 'F:\\4k6k\\datasets\\ns_images\\vixen\\512_with_ref_and_fkaw',
        # Per-scale settings: entry i of 'crop_sz', 'step' and 'resize_final_img'
        # are used together for one scale.
        'crop_sz': [1024, 2048],  # side length (source pixels) of each square sub-image
        'step': [1024, 2048],  # step of the sliding crop window
        # If the strip left over after the last window is wider than this, one
        # extra window aligned to the image edge is added.
        'thres_sz': 512,
        # Each tile is resized to crop_sz * factor; the reference image is resized
        # to crop_sz[0] * resize_final_img[0].
        'resize_final_img': [.5, .25],
        'only_resize': False,  # NOTE(review): not read anywhere in this file.
        # Split every image into left/right halves and process each separately.
        'vertical_split': False,
        # As described, images larger than this dimensional size will be halved
        # before anything else is done. This helps prevent images from cameras
        # with "false-megapixels" from polluting the dataset.
        # False-megapixel = lots of noise at ultra-high res.
        'input_image_max_size_before_being_halved': 5500,
    }

    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        # exist_ok guards against the folder appearing between the check and the create.
        os.makedirs(save_folder, exist_ok=True)
        print('mkdir [{:s}] ...'.format(save_folder))
    if opt['dest'] == 'lmdb':
        writer = LmdbWriter(save_folder)
    else:
        writer = FileWriter(save_folder)
    extract_single(opt, writer)
2020-09-04 21:30:34 +00:00
class LmdbWriter:
    """Write reference images and their tiles into an LMDB database.

    Keys are ASCII strings:

    - Reference images use their integer ID (``"0"``, ``"1"``, ...).
    - Tiles use ``"<ref_id>_<tile_index>"``.
    - Each image's center point is stored under ``"<key>_center"``.

    Values are serialized with ``pyarrow.serialize``. Tile keys (not reference
    keys) are collected in ``self.keys`` and written under ``b'__keys__'`` /
    ``b'__len__'`` by ``close()``.
    """

    def __init__(self, lmdb_path, max_mem_size=30 * 1024 * 1024 * 1024, write_freq=5000):
        # Imported here: only this writer needs lmdb, and the module does not
        # import it at the top level.
        import lmdb
        self.db = lmdb.open(lmdb_path, subdir=True,
                            map_size=max_mem_size, readonly=False,
                            meminit=False, map_async=True)
        self.txn = self.db.begin(write=True)
        self.ref_id = 0  # next reference-image ID to hand out
        self.tile_ids = {}  # reference ID -> next tile index for that reference
        self.writes = 0  # images written so far; drives periodic commits
        self.write_freq = write_freq  # commit after every `write_freq` images
        self.keys = []  # tile keys, saved under b'__keys__' on close

    @staticmethod
    def _serialize(obj):
        """Serialize ``obj`` with pyarrow and return the buffer to store in LMDB."""
        # NOTE(review): pyarrow.serialize is deprecated since pyarrow 2.0 and removed
        # in later releases; kept so whatever reads these databases keeps working.
        import pyarrow
        return pyarrow.serialize(obj).to_buffer()

    def write_reference_image(self, ref_img, _):
        """Write a reference image and return its new integer ID.

        ``ref_img`` is ``(buffer, center_point, size)``; the size is not stored.
        The second argument (the source path) is ignored by this writer.
        """
        new_id = self.ref_id
        self.ref_id += 1
        self.write_image(new_id, ref_img[0], ref_img[1])
        return new_id

    def write_tile_image(self, ref_id, tile_image):
        """Write a tile of reference ``ref_id`` and return its key ``"<ref_id>_<n>"``.

        ``tile_image`` is ``(buffer, center_point, tile_size)``; the tile size is
        not stored.
        """
        next_tile_id = self.tile_ids.get(ref_id, 0)
        self.tile_ids[ref_id] = next_tile_id + 1
        full_id = "%i_%i" % (ref_id, next_tile_id)
        self.write_image(full_id, tile_image[0], tile_image[1])
        self.keys.append(full_id)
        return full_id

    def write_image(self, id, img, center_point):
        """Write ``img`` under ``id`` and its center point under ``"<id>_center"``.

        Commits every ``write_freq`` images.
        """
        key = u'{}'.format(id).encode('ascii')
        # Transaction.put's third positional parameter is ``dupdata`` (a flag), so
        # the center point cannot ride along there; it gets its own key.
        self.txn.put(key, self._serialize(img))
        self.txn.put(key + b'_center', self._serialize(center_point))
        self.writes += 1
        if self.writes % self.write_freq == 0:
            self.flush()

    def flush(self):
        """Commit pending writes and open a fresh write transaction."""
        self.txn.commit()
        self.txn = self.db.begin(write=True)

    def close(self):
        """Commit pending writes, store the tile key index, then sync and close the DB."""
        self.txn.commit()
        with self.db.begin(write=True) as txn:
            txn.put(b'__keys__', self._serialize(self.keys))
            txn.put(b'__len__', self._serialize(len(self.keys)))
        self.db.sync()
        self.db.close()
2020-09-17 19:30:51 +00:00
class FileWriter:
    """Write reference images and their tiles as JPEG files, one folder per reference.

    Layout under ``folder``::

        <ref_name>/ref.jpg       the reference image
        <ref_name>/%08i.jpg      one file per tile, named by its unique ID
        <ref_name>/centers.pt    torch.save'd {tile_id: (center_point, tile_size)}
    """

    def __init__(self, folder):
        self.folder = folder
        self.next_unique_id = 0
        # ref_img basename -> {tile ID: (center point, tile size)}; cleared by flush().
        self.ref_center_points = {}
        # reference ID -> ref_img basename (its folder name).
        self.ref_ids_to_names = {}

    def get_next_unique_id(self):
        """Return the next ID (shared by reference images and tiles) and advance the counter."""
        next_id = self.next_unique_id
        self.next_unique_id += 1
        return next_id

    def save_image(self, ref_path, img_name, img):
        """Write the encoded bytes ``img`` to ``<folder>/<ref_path>/<img_name>``."""
        save_path = osp.join(self.folder, ref_path)
        os.makedirs(save_path, exist_ok=True)
        # The context manager closes the handle even if the write raises.
        with open(osp.join(save_path, img_name), "wb") as f:
            f.write(img)

    def write_reference_image(self, ref_img, path):
        """Save ``ref_img`` as ``ref.jpg`` in a folder named after ``path`` and return its ID.

        ``ref_img`` is ``(buffer, center_point, size)``. The center point and size
        are ignored. ``.jpg``/``.png`` are stripped from the basename of ``path``
        to form the folder name.
        """
        # Encoded with a center point, which is irrelevant for the reference image.
        ref_buffer, _, _ = ref_img
        img_name = osp.basename(path).replace(".jpg", "").replace(".png", "")
        # NOTE(review): inputs sharing a basename map to the same folder and
        # overwrite each other's ref.jpg / centers.pt.
        self.ref_center_points[img_name] = {}
        self.save_image(img_name, "ref.jpg", ref_buffer)
        ref_id = self.get_next_unique_id()
        self.ref_ids_to_names[ref_id] = img_name
        return ref_id

    def write_tile_image(self, ref_id, tile_image):
        """Save a tile of reference ``ref_id`` as ``%08i.jpg`` and return its ID.

        ``tile_image`` is ``(buffer, center_point, tile_size)``. The center point
        and tile size are recorded for ``centers.pt``.
        """
        tile_id = self.get_next_unique_id()
        ref_name = self.ref_ids_to_names[ref_id]
        img, center, tile_sz = tile_image
        self.ref_center_points[ref_name][tile_id] = center, tile_sz
        self.save_image(ref_name, "%08i.jpg" % (tile_id,), img)
        return tile_id

    def flush(self):
        """Write ``centers.pt`` for every pending reference, then forget them."""
        for ref_name, cps in self.ref_center_points.items():
            torch.save(cps, osp.join(self.folder, ref_name, "centers.pt"))
        self.ref_center_points = {}

    def close(self):
        """Flush any pending center-point data."""
        self.flush()
2020-09-17 19:30:51 +00:00
2020-09-04 21:30:34 +00:00
class TiledDataset(data.Dataset):
    """Dataset yielding, per input image, a square reference image plus multi-scale tiles.

    Each item is a 2-tuple:

    - In ``vertical_split`` mode it holds the results for the left and right
      halves of the image.
    - Otherwise the second entry is None.

    Each result is either None (image skipped) or ``(results, path)``:

    - ``results[0]`` is ``(jpeg_buffer, (-1, -1), (-1, -1))`` for the reference
      image.
    - Every later entry is ``(jpeg_buffer, center_point, tile_size)`` for one
      tile.
    """

    def __init__(self, opt):
        self.split_mode = opt['vertical_split']
        self.opt = opt
        input_folder = opt['input_folder']
        self.images = data_util._get_paths_from_images(input_folder)

    def __getitem__(self, index):
        if self.split_mode:
            return (self.get(index, True, True), self.get(index, True, False))
        # Wrap in a tuple to align with split mode.
        return (self.get(index, False, False), None)

    def _encode_jpg(self, img):
        """JPEG-encode ``img`` at ``opt['compression_level']`` quality and return the buffer."""
        success, buffer = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
        if not success:
            raise RuntimeError("cv2.imencode failed to encode an image as JPEG")
        return buffer

    def get_for_scale(self, img, crop_sz, step, resize_factor, ref_resize_factor):
        """Slide a ``crop_sz`` window over ``img`` and return the encoded tiles.

        Args:
            img: HxWxC image array (the squared image built by ``get``).
            crop_sz: side length, in ``img`` pixels, of each square tile.
            step: stride of the sliding window.
            resize_factor: each tile is resized to ``int(crop_sz * resize_factor)``
                pixels square.
            ref_resize_factor: ``img`` size divided by reference-image size. Center
                points and tile sizes are divided by it so they are expressed in
                reference-image pixels.

        Returns:
            List of ``(jpeg_buffer, (row, col) center, tile_size)`` tuples. The list
            is empty when ``img`` is smaller than ``crop_sz`` in either dimension.
        """
        thres_sz = self.opt['thres_sz']
        h, w = img.shape[:2]
        # Both dimensions must fit one window; otherwise the np.arange grids below
        # are empty and indexing [-1] raises.
        if crop_sz > h or crop_sz > w:
            return []
        h_space = np.arange(0, h - crop_sz + 1, step)
        if h - (h_space[-1] + crop_sz) > thres_sz:
            # Leftover strip is large enough to deserve a window flush with the edge.
            h_space = np.append(h_space, h - crop_sz)
        w_space = np.arange(0, w - crop_sz + 1, step)
        if w - (w_space[-1] + crop_sz) > thres_sz:
            w_space = np.append(w_space, w - crop_sz)

        tile_dim = int(crop_sz * resize_factor)
        dsize = (tile_dim, tile_dim)
        results = []
        for x in h_space:
            for y in w_space:
                crop_img = np.ascontiguousarray(img[x:x + crop_sz, y:y + crop_sz, :])
                # Center point needs to be resized by ref_resize_factor - since it is
                # relative to the reference image.
                center_point = (int((x + crop_sz // 2) // ref_resize_factor),
                                int((y + crop_sz // 2) // ref_resize_factor))
                if 'resize_final_img' in self.opt:
                    crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA)
                results.append((self._encode_jpg(crop_img), center_point, int(crop_sz // ref_resize_factor)))
        return results

    def get(self, index, split_mode, left_img):
        """Load image ``index`` and build its reference image and tiles.

        Args:
            index: position in ``self.images``.
            split_mode: if True, use only one half (left or right) of the image.
            left_img: in split mode, True selects the left half.

        Returns:
            None when the image cannot be read, is greyscale (2-D), or has a
            shorter side under 512 px (checked after optional halving).
            Otherwise ``(results, path)``; see the class docstring.
        """
        path = self.images[index]
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # Unreadable files and greyscale (2-D) images are not supported.
        if img is None or img.ndim == 2:
            return None

        h, w = img.shape[:2]
        if max(h, w) > self.opt['input_image_max_size_before_being_halved']:
            h, w = h // 2, w // 2
            img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA)
        # Filter out images whose shorter side does not meet the threshold size.
        if min(h, w) < 512:
            return None

        # Handle splitting the image if needed.
        if split_mode:
            half = w // 2
            left, right = (0, half) if left_img else (half, w)
            img = img[:, left:right]
            # Track the width of the half actually kept, so the square crop below
            # is computed against it rather than the full image width.
            w = right - left

        # We must convert the image into a square.
        dim = min(h, w)
        if split_mode:
            # Crop towards the center seam, which makes more sense in split mode.
            if left_img:
                img = img[-dim:, -dim:, :]
            else:
                img = img[:dim, :dim, :]
        else:
            # Keep only the center, since this is often the most salient part.
            top = (h - dim) // 2
            lft = (w - dim) // 2
            img = img[top:top + dim, lft:lft + dim, :]
        h, w = img.shape[:2]

        tile_dim = int(self.opt['crop_sz'][0] * self.opt['resize_final_img'][0])
        dsize = (tile_dim, tile_dim)
        ref_resize_factor = h / tile_dim
        # Reference image should always be first entry in results.
        ref_img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)
        results = [(self._encode_jpg(ref_img), (-1, -1), (-1, -1))]

        for crop_sz, resize_factor, step in zip(self.opt['crop_sz'], self.opt['resize_final_img'], self.opt['step']):
            results.extend(self.get_for_scale(img, crop_sz, step, resize_factor, ref_resize_factor))
        return results, path

    def __len__(self):
        return len(self.images)
def identity(x):
    """Hand a batch back untouched.

    Passed to ``DataLoader`` as ``collate_fn``. PyTorch's default collation is
    bypassed, so the raw list of dataset items (tuples, encoded buffers, or
    None) reaches the consumer unchanged.
    """
    batch = x
    return batch
2020-10-24 17:57:39 +00:00
def extract_single(opt, writer):
    """Run ``TiledDataset`` over the configured images and persist every result via ``writer``.

    For each image (or each half in vertical-split mode) that produced at least
    one tile:

    - the reference image is written first, labelled ``<path>_left`` or
      ``<path>_right``;
    - its tiles are written next;
    - ``writer.flush()`` is then called if the writer defines it.

    The writer is always closed, even if extraction fails partway, so data
    already written is committed.

    Args:
        opt: option dict (see ``main``); ``'n_thread'`` sets the worker count.
        writer: object with ``write_reference_image``, ``write_tile_image`` and
            ``close`` (``FileWriter`` / ``LmdbWriter``).
    """
    dataset = TiledDataset(opt)
    # identity collate + default batch_size of 1: each batch is a one-element
    # list holding the dataset item unchanged.
    dataloader = data.DataLoader(dataset, num_workers=opt['n_thread'], collate_fn=identity)
    # Not every writer is guaranteed to implement flush(); resolve it once.
    flush = getattr(writer, 'flush', None)
    try:
        for batch in tqdm(dataloader):
            if not batch:
                continue
            halves = batch[0]
            # In non-split mode the 'right' slot is None and is skipped.
            for result, lbl in zip(halves, ('left', 'right')):
                if result is None:
                    continue
                imgs, path = result
                # Need the reference image plus at least one tile.
                if imgs is None or len(imgs) <= 1:
                    continue
                ref_id = writer.write_reference_image(imgs[0], path + "_" + lbl)
                for tile in imgs[1:]:
                    writer.write_tile_image(ref_id, tile)
            if flush is not None:
                flush()
    finally:
        writer.close()
2020-09-04 21:30:34 +00:00
if __name__ == "__main__":
    # Run the extraction only when executed as a script, not when imported.
    main()