From 327cdbe110382301125609036e03836ba8314a9d Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 24 Oct 2020 11:57:39 -0600 Subject: [PATCH] Support configurable multi-modal training --- codes/multi_modal_train.py | 17 ++- codes/process_video.py | 7 + codes/scripts/extract_subimages_with_ref.py | 161 ++++++++------------ 3 files changed, 82 insertions(+), 103 deletions(-) diff --git a/codes/multi_modal_train.py b/codes/multi_modal_train.py index fd32ba17..3b4cf1cb 100644 --- a/codes/multi_modal_train.py +++ b/codes/multi_modal_train.py @@ -9,8 +9,13 @@ # models are shared. Your best bet is to have all models save state at the same time so that they all load ~ the same # state when re-started. import argparse + +import yaml + import train import utils.options as option +from utils.util import OrderedYaml + def main(master_opt, launcher): trainers = [] @@ -40,7 +45,11 @@ if __name__ == '__main__': #parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() - opt = { - 'trainer_options': ['../options/teco.yml', '../options/exd.yml'] - } - main(opt, args.launcher) \ No newline at end of file + + Loader, Dumper = OrderedYaml() + with open(args.opt, mode='r') as f: + opt = yaml.load(f, Loader=Loader) + opt = { + 'trainer_options': ['../options/teco.yml', '../options/exd.yml'] + } + main(opt, args.launcher) \ No newline at end of file diff --git a/codes/process_video.py b/codes/process_video.py index 24a6181e..368e1785 100644 --- a/codes/process_video.py +++ b/codes/process_video.py @@ -155,6 +155,13 @@ if __name__ == "__main__": if recurrent_mode and first_frame: b, c, h, w = data['LQ'].shape recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device) + # Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of. + if 'recurrent_hr_generator' in opt.keys(): + recurrent_gen = model.env['generators']['generator'] + model.env['generators']['generator'] = model.env['generators'][opt['recurrent_hr_generator']] + else: + model.env['generators']['generator'] = recurrent_gen + first_frame = False if recurrent_mode: data['recurrent'] = recurrent_entry diff --git a/codes/scripts/extract_subimages_with_ref.py b/codes/scripts/extract_subimages_with_ref.py index cb6c4465..0e058369 100644 --- a/codes/scripts/extract_subimages_with_ref.py +++ b/codes/scripts/extract_subimages_with_ref.py @@ -11,7 +11,6 @@ import torch def main(): - mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs) split_img = False opt = {} opt['n_thread'] = 2 @@ -19,75 +18,27 @@ def main(): # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. - if mode == 'single': - opt['dest'] = 'file' - opt['input_folder'] = 'F:\\4k6k\\datasets\\images\\fullvideo\\full_images' - opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\fullvideo\\256_tiled' - opt['crop_sz'] = [512, 1024] # the size of each sub-image - opt['step'] = [512, 1024] # step of the sliding crop window - opt['thres_sz'] = 128 # size threshold - opt['resize_final_img'] = [.5, .25] - opt['only_resize'] = False + opt['dest'] = 'file' + opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\images_sized' + opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\paired_images' + opt['crop_sz'] = [512, 1024] # the size of each sub-image + opt['step'] = [512, 1024] # step of the sliding crop window + opt['thres_sz'] = 128 # size threshold + opt['resize_final_img'] = [.5, .25] + opt['only_resize'] = False + opt['vertical_split'] = True - save_folder = opt['save_folder'] - if not osp.exists(save_folder): - os.makedirs(save_folder) - print('mkdir [{:s}] ...'.format(save_folder)) + save_folder = opt['save_folder'] + if not osp.exists(save_folder): + os.makedirs(save_folder) + print('mkdir [{:s}] ...'.format(save_folder)) - if opt['dest'] == 'lmdb': - writer = LmdbWriter(save_folder) - else: - writer = FileWriter(save_folder) - - extract_single(opt, writer, split_img) - elif mode == 'pair': - GT_folder = '../../datasets/div2k/DIV2K_train_HR' - LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4' - save_GT_folder = '../../datasets/div2k/DIV2K800_sub' - save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4' - scale_ratio = 4 - crop_sz = 480 # the size of each sub-image (GT) - step = 240 # step of the sliding crop window (GT) - thres_sz = 48 # size threshold - ######################################################################## - # check that all the GT and LR images have correct scale ratio - img_GT_list = data_util._get_paths_from_images(GT_folder) - img_LR_list = data_util._get_paths_from_images(LR_folder) - assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.' - for path_GT, path_LR in zip(img_GT_list, img_LR_list): - img_GT = Image.open(path_GT) - img_LR = Image.open(path_LR) - w_GT, h_GT = img_GT.size - w_LR, h_LR = img_LR.size - assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 - w_GT, scale_ratio, w_LR, path_GT) - assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 - w_GT, scale_ratio, w_LR, path_GT) - # check crop size, step and threshold size - assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format( - scale_ratio) - assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio) - assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format( - scale_ratio) - print('process GT...') - opt['input_folder'] = GT_folder - opt['save_folder'] = save_GT_folder - opt['crop_sz'] = crop_sz - opt['step'] = step - opt['thres_sz'] = thres_sz - extract_single(opt) - print('process LR...') - opt['input_folder'] = LR_folder - opt['save_folder'] = save_LR_folder - opt['crop_sz'] = crop_sz // scale_ratio - opt['step'] = step // scale_ratio - opt['thres_sz'] = thres_sz // scale_ratio - extract_single(opt) - assert len(data_util._get_paths_from_images(save_GT_folder)) == len( - data_util._get_paths_from_images( - save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.' + if opt['dest'] == 'lmdb': + writer = LmdbWriter(save_folder) else: - raise ValueError('Wrong mode.') + writer = FileWriter(save_folder) + + extract_single(opt, writer) class LmdbWriter: @@ -182,26 +133,22 @@ class FileWriter: self.flush() class TiledDataset(data.Dataset): - def __init__(self, opt, split_mode=False): - self.split_mode = split_mode + def __init__(self, opt): + self.split_mode = opt['vertical_split'] self.opt = opt input_folder = opt['input_folder'] self.images = data_util._get_paths_from_images(input_folder) def __getitem__(self, index): if self.split_mode: - return self.get(index, True, True).extend(self.get(index, True, False)) + return (self.get(index, True, True), self.get(index, True, False)) else: - return self.get(index, False, False) - - def get_for_scale(self, img, split_mode, left_image, crop_sz, step, resize_factor, ref_resize_factor): - assert not left_image # Split image not yet supported, False is the default value. + # Wrap in a tuple to align with split mode. + return (self.get(index, False, False), None) + def get_for_scale(self, img, crop_sz, step, resize_factor, ref_resize_factor): thres_sz = self.opt['thres_sz'] - h, w, c = img.shape - if split_mode: - w = w/2 h_space = np.arange(0, h - crop_sz + 1, step) if h - (h_space[-1] + crop_sz) > thres_sz: @@ -231,30 +178,41 @@ class TiledDataset(data.Dataset): def get(self, index, split_mode, left_img): path = self.images[index] img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - - # We must convert the image into a square. Crop the image so that only the center is left, since this is often - # the most salient part of the image. - if len(img.shape) == 2: # Greyscale not supported. - return None h, w, c = img.shape - dim = min(h, w) - img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :] - h, w, c = img.shape # Uncomment to filter any image that doesnt meet a threshold size. if min(h,w) < 1024: return None + # Greyscale not supported. + if len(img.shape) == 2: + return None + + # Handle splitting the image if needed. left = 0 right = w if split_mode: if left_img: left = 0 - right = int(w/2) + right = w//2 else: - left = int(w/2) + left = w//2 right = w img = img[:, left:right] + # We must convert the image into a square. + dim = min(h, w) + if split_mode: + # Crop the image towards the center, which makes more sense in split mode. + if left_img: + img = img[-dim:, -dim:, :] + else: + img = img[:dim, :dim, :] + else: + # Crop the image so that only the center is left, since this is often the most salient part of the image. + img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :] + + h, w, c = img.shape + tile_dim = int(self.opt['crop_sz'][0] * self.opt['resize_final_img'][0]) dsize = (tile_dim, tile_dim) ref_resize_factor = h / tile_dim @@ -266,7 +224,7 @@ class TiledDataset(data.Dataset): results = [(ref_buffer, (-1,-1), (-1,-1))] for crop_sz, resize_factor, step in zip(self.opt['crop_sz'], self.opt['resize_final_img'], self.opt['step']): - results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, step, resize_factor, ref_resize_factor)) + results.extend(self.get_for_scale(img, crop_sz, step, resize_factor, ref_resize_factor)) return results, path def __len__(self): @@ -276,20 +234,25 @@ class TiledDataset(data.Dataset): def identity(x): return x -def extract_single(opt, writer, split_img=False): - dataset = TiledDataset(opt, split_img) +def extract_single(opt, writer): + dataset = TiledDataset(opt) dataloader = data.DataLoader(dataset, num_workers=opt['n_thread'], collate_fn=identity) tq = tqdm(dataloader) - for imgs in tq: - if imgs is None or imgs[0] is None: + for spl_imgs in tq: + if spl_imgs is None: continue - imgs, path = imgs[0] - if imgs is None or len(imgs) <= 1: - continue - ref_id = writer.write_reference_image(imgs[0], path) - for tile in imgs[1:]: - writer.write_tile_image(ref_id, tile) - writer.flush() + spl_imgs = spl_imgs[0] + for imgs, lbl in zip(list(spl_imgs), ['left', 'right']): + if imgs is None: + continue + imgs, path = imgs + if imgs is None or len(imgs) <= 1: + continue + path = path + "_" + lbl + ref_id = writer.write_reference_image(imgs[0], path) + for tile in imgs[1:]: + writer.write_tile_image(ref_id, tile) + writer.flush() writer.close()