From 64a24503f66489929522112a8a81e1c560371584 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 4 Sep 2020 15:30:34 -0600 Subject: [PATCH] Add extract_subimages_with_ref_lmdb for generating lmdb with reference images --- codes/data_scripts/extract_subimages.py | 1 - .../extract_subimages_with_ref_lmdb.py | 236 ++++++++++++++ codes/data_scripts/generate_LR_Vimeo90K.m | 49 --- codes/data_scripts/generate_mod_LR_bic.m | 82 ----- codes/data_scripts/generate_mod_LR_bic.py | 81 ----- .../data_scripts/prepare_DIV2K_x4_dataset.sh | 42 --- codes/data_scripts/regroup_REDS.py | 11 - codes/train2.py | 289 ------------------ 8 files changed, 236 insertions(+), 555 deletions(-) create mode 100644 codes/data_scripts/extract_subimages_with_ref_lmdb.py delete mode 100644 codes/data_scripts/generate_LR_Vimeo90K.m delete mode 100644 codes/data_scripts/generate_mod_LR_bic.m delete mode 100644 codes/data_scripts/generate_mod_LR_bic.py delete mode 100644 codes/data_scripts/prepare_DIV2K_x4_dataset.sh delete mode 100644 codes/data_scripts/regroup_REDS.py delete mode 100644 codes/train2.py diff --git a/codes/data_scripts/extract_subimages.py b/codes/data_scripts/extract_subimages.py index 7d0cbd60..df8d3ef7 100644 --- a/codes/data_scripts/extract_subimages.py +++ b/codes/data_scripts/extract_subimages.py @@ -77,7 +77,6 @@ def main(): else: raise ValueError('Wrong mode.') - def extract_single(opt, split_img=False): input_folder = opt['input_folder'] save_folder = opt['save_folder'] diff --git a/codes/data_scripts/extract_subimages_with_ref_lmdb.py b/codes/data_scripts/extract_subimages_with_ref_lmdb.py new file mode 100644 index 00000000..0a99ac4b --- /dev/null +++ b/codes/data_scripts/extract_subimages_with_ref_lmdb.py @@ -0,0 +1,236 @@ +"""A multi-thread tool to crop large images to sub-images for faster IO.""" +import os +import os.path as osp +import numpy as np +import cv2 +from PIL import Image +import data.util as data_util # noqa: E402 +import lmdb +import pyarrow +import torch.utils.data as data +from tqdm import tqdm + + +def main(): + mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs) + split_img = False + opt = {} + opt['n_thread'] = 12 + opt['compression_level'] = 90 # JPEG compression quality rating. + # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer + # compression time. If read raw images during training, use 0 for faster IO speed. + if mode == 'single': + opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images' + opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\lmdb_with_ref' + opt['crop_sz'] = 512 # the size of each sub-image + opt['step'] = 128 # step of the sliding crop window + opt['thres_sz'] = 128 # size threshold + opt['resize_final_img'] = .5 + opt['only_resize'] = False + extract_single(opt, split_img) + elif mode == 'pair': + GT_folder = '../../datasets/div2k/DIV2K_train_HR' + LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4' + save_GT_folder = '../../datasets/div2k/DIV2K800_sub' + save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4' + scale_ratio = 4 + crop_sz = 480 # the size of each sub-image (GT) + step = 240 # step of the sliding crop window (GT) + thres_sz = 48 # size threshold + ######################################################################## + # check that all the GT and LR images have correct scale ratio + img_GT_list = data_util._get_paths_from_images(GT_folder) + img_LR_list = data_util._get_paths_from_images(LR_folder) + assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.' + for path_GT, path_LR in zip(img_GT_list, img_LR_list): + img_GT = Image.open(path_GT) + img_LR = Image.open(path_LR) + w_GT, h_GT = img_GT.size + w_LR, h_LR = img_LR.size + assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 + w_GT, scale_ratio, w_LR, path_GT) + assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 + w_GT, scale_ratio, w_LR, path_GT) + # check crop size, step and threshold size + assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format( + scale_ratio) + assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio) + assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format( + scale_ratio) + print('process GT...') + opt['input_folder'] = GT_folder + opt['save_folder'] = save_GT_folder + opt['crop_sz'] = crop_sz + opt['step'] = step + opt['thres_sz'] = thres_sz + extract_single(opt) + print('process LR...') + opt['input_folder'] = LR_folder + opt['save_folder'] = save_LR_folder + opt['crop_sz'] = crop_sz // scale_ratio + opt['step'] = step // scale_ratio + opt['thres_sz'] = thres_sz // scale_ratio + extract_single(opt) + assert len(data_util._get_paths_from_images(save_GT_folder)) == len( + data_util._get_paths_from_images( + save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.' + else: + raise ValueError('Wrong mode.') + + +class LmdbWriter: + def __init__(self, lmdb_path, max_mem_size=30*1024*1024*1024, write_freq=5000): + self.db = lmdb.open(lmdb_path, subdir=True, + map_size=max_mem_size, readonly=False, + meminit=False, map_async=True) + self.txn = self.db.begin(write=True) + self.ref_id = 0 + self.tile_ids = {} + self.writes = 0 + self.write_freq = write_freq + self.keys = [] + + # Writes the given reference image to the db and returns its ID. + def write_reference_image(self, ref_img): + id = self.ref_id + self.ref_id += 1 + self.write_image(id, ref_img[0], ref_img[1]) + return id + + # Writes a tile image to the db given a reference image and returns its ID. + def write_tile_image(self, ref_id, tile_image): + next_tile_id = 0 if ref_id not in self.tile_ids.keys() else self.tile_ids[ref_id] + self.tile_ids[ref_id] = next_tile_id+1 + full_id = "%i_%i" % (ref_id, next_tile_id) + self.write_image(full_id, tile_image[0], tile_image[1]) + self.keys.append(full_id) + return full_id + + # Writes an image directly to the db with the given reference image and center point. + def write_image(self, id, img, center_point): + self.txn.put(u'{}'.format(id).encode('ascii'), pyarrow.serialize(img).to_buffer(), pyarrow.serialize(center_point).to_buffer()) + self.writes += 1 + if self.writes % self.write_freq == 0: + self.txn.commit() + self.txn = self.db.begin(write=True) + + def close(self): + self.txn.commit() + with self.db.begin(write=True) as txn: + txn.put(b'__keys__', pyarrow.serialize(self.keys).to_buffer()) + txn.put(b'__len__', pyarrow.serialize(len(self.keys)).to_buffer()) + self.db.sync() + self.db.close() + + +class TiledDataset(data.Dataset): + def __init__(self, opt, split_mode=False): + self.split_mode = split_mode + self.opt = opt + input_folder = opt['input_folder'] + self.images = data_util._get_paths_from_images(input_folder) + + def __getitem__(self, index): + if self.split_mode: + return self.get(index, True, True).extend(self.get(index, True, False)) + else: + return self.get(index, False, False) + + def get(self, index, split_mode, left_img): + path = self.images[index] + crop_sz = self.opt['crop_sz'] + step = self.opt['step'] + thres_sz = self.opt['thres_sz'] + only_resize = self.opt['only_resize'] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + + # We must convert the image into a square. Crop the image so that only the center is left, since this is often + # the most salient part of the image. + h, w, c = img.shape + dim = min(h, w) + img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :] + + h, w, c = img.shape + # Uncomment to filter any image that doesnt meet a threshold size. + if min(h,w) < 1024: + return + left = 0 + right = w + if split_mode: + if left_img: + left = 0 + right = int(w/2) + else: + left = int(w/2) + right = w + w = int(w/2) + img = img[:, left:right] + + h_space = np.arange(0, h - crop_sz + 1, step) + if h - (h_space[-1] + crop_sz) > thres_sz: + h_space = np.append(h_space, h - crop_sz) + w_space = np.arange(0, w - crop_sz + 1, step) + if w - (w_space[-1] + crop_sz) > thres_sz: + w_space = np.append(w_space, w - crop_sz) + + dsize = None + if only_resize: + dsize = (crop_sz, crop_sz) + if h < w: + h_space = [0] + w_space = [(w - h) // 2] + crop_sz = h + else: + h_space = [(h - w) // 2] + w_space = [0] + crop_sz = w + + index = 0 + resize_factor = self.opt['resize_final_img'] if 'resize_final_img' in self.opt.keys() else 1 + dsize = (int(crop_sz * resize_factor), int(crop_sz * resize_factor)) + # Reference image should always be first. + results = [(cv2.resize(img, dsize, interpolation=cv2.INTER_AREA), (-1,-1))] + for x in h_space: + for y in w_space: + index += 1 + crop_img = img[x:x + crop_sz, y:y + crop_sz, :] + center_point = (x + crop_sz // 2, y + crop_sz // 2) + crop_img = np.ascontiguousarray(crop_img) + if 'resize_final_img' in self.opt.keys(): + # Resize too. + resize_factor = self.opt['resize_final_img'] + center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor)) + crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA) + success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) + assert success + results.append((buffer, center_point)) + return results + + def __len__(self): + return len(self.images) + + +def identity(x): + return x + +def extract_single(opt, split_img=False): + save_folder = opt['save_folder'] + if not osp.exists(save_folder): + os.makedirs(save_folder) + print('mkdir [{:s}] ...'.format(save_folder)) + lmdb = LmdbWriter(save_folder) + + dataset = TiledDataset(opt, split_img) + dataloader = data.DataLoader(dataset, num_workers=opt['n_thread'], collate_fn=identity) + tq = tqdm(dataloader) + for imgs in tq: + if imgs is None or len(imgs) <= 1: + continue + ref_id = lmdb.write_reference_image(imgs[0]) + for tile in imgs[1:]: + lmdb.write_tile_image(ref_id, tile) + lmdb.close() + + +if __name__ == '__main__': + main() diff --git a/codes/data_scripts/generate_LR_Vimeo90K.m b/codes/data_scripts/generate_LR_Vimeo90K.m deleted file mode 100644 index acce7898..00000000 --- a/codes/data_scripts/generate_LR_Vimeo90K.m +++ /dev/null @@ -1,49 +0,0 @@ -function generate_LR_Vimeo90K() -%% matlab code to genetate bicubic-downsampled for Vimeo90K dataset - -up_scale = 4; -mod_scale = 4; -idx = 0; -filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png'); -for i = 1 : length(filepaths) - [~,imname,ext] = fileparts(filepaths(i).name); - folder_path = filepaths(i).folder; - save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4'); - if ~exist(save_LR_folder, 'dir') - mkdir(save_LR_folder); - end - if isempty(imname) - disp('Ignore . folder.'); - elseif strcmp(imname, '.') - disp('Ignore .. folder.'); - else - idx = idx + 1; - str_rlt = sprintf('%d\t%s.\n', idx, imname); - fprintf(str_rlt); - % read image - img = imread(fullfile(folder_path, [imname, ext])); - img = im2double(img); - % modcrop - img = modcrop(img, mod_scale); - % LR - im_LR = imresize(img, 1/up_scale, 'bicubic'); - if exist('save_LR_folder', 'var') - imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); - end - end -end -end - -%% modcrop -function img = modcrop(img, modulo) -if size(img,3) == 1 - sz = size(img); - sz = sz - mod(sz, modulo); - img = img(1:sz(1), 1:sz(2)); -else - tmpsz = size(img); - sz = tmpsz(1:2); - sz = sz - mod(sz, modulo); - img = img(1:sz(1), 1:sz(2),:); -end -end diff --git a/codes/data_scripts/generate_mod_LR_bic.m b/codes/data_scripts/generate_mod_LR_bic.m deleted file mode 100644 index 05a9c61a..00000000 --- a/codes/data_scripts/generate_mod_LR_bic.m +++ /dev/null @@ -1,82 +0,0 @@ -function generate_mod_LR_bic() -%% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images. - -%% set parameters -% comment the unnecessary line -input_folder = '../../datasets/DIV2K/DIV2K800'; -% save_mod_folder = ''; -save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4'; -% save_bic_folder = ''; - -up_scale = 4; -mod_scale = 4; - -if exist('save_mod_folder', 'var') - if exist(save_mod_folder, 'dir') - disp(['It will cover ', save_mod_folder]); - else - mkdir(save_mod_folder); - end -end -if exist('save_LR_folder', 'var') - if exist(save_LR_folder, 'dir') - disp(['It will cover ', save_LR_folder]); - else - mkdir(save_LR_folder); - end -end -if exist('save_bic_folder', 'var') - if exist(save_bic_folder, 'dir') - disp(['It will cover ', save_bic_folder]); - else - mkdir(save_bic_folder); - end -end - -idx = 0; -filepaths = dir(fullfile(input_folder,'*.*')); -for i = 1 : length(filepaths) - [paths,imname,ext] = fileparts(filepaths(i).name); - if isempty(imname) - disp('Ignore . folder.'); - elseif strcmp(imname, '.') - disp('Ignore .. folder.'); - else - idx = idx + 1; - str_rlt = sprintf('%d\t%s.\n', idx, imname); - fprintf(str_rlt); - % read image - img = imread(fullfile(input_folder, [imname, ext])); - img = im2double(img); - % modcrop - img = modcrop(img, mod_scale); - if exist('save_mod_folder', 'var') - imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); - end - % LR - im_LR = imresize(img, 1/up_scale, 'bicubic'); - if exist('save_LR_folder', 'var') - imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); - end - % Bicubic - if exist('save_bic_folder', 'var') - im_B = imresize(im_LR, up_scale, 'bicubic'); - imwrite(im_B, fullfile(save_bic_folder, [imname, '.png'])); - end - end -end -end - -%% modcrop -function img = modcrop(img, modulo) -if size(img,3) == 1 - sz = size(img); - sz = sz - mod(sz, modulo); - img = img(1:sz(1), 1:sz(2)); -else - tmpsz = size(img); - sz = tmpsz(1:2); - sz = sz - mod(sz, modulo); - img = img(1:sz(1), 1:sz(2),:); -end -end diff --git a/codes/data_scripts/generate_mod_LR_bic.py b/codes/data_scripts/generate_mod_LR_bic.py deleted file mode 100644 index 59b313a8..00000000 --- a/codes/data_scripts/generate_mod_LR_bic.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import sys -import cv2 -import numpy as np - -try: - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from data.util import imresize_np -except ImportError: - pass - - -def generate_mod_LR_bic(): - # set parameters - up_scale = 4 - mod_scale = 4 - # set data dir - sourcedir = '/data/datasets/img' - savedir = '/data/datasets/mod' - - saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale)) - saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale)) - saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale)) - - if not os.path.isdir(sourcedir): - print('Error: No source data found') - exit(0) - if not os.path.isdir(savedir): - os.mkdir(savedir) - - if not os.path.isdir(os.path.join(savedir, 'HR')): - os.mkdir(os.path.join(savedir, 'HR')) - if not os.path.isdir(os.path.join(savedir, 'LR')): - os.mkdir(os.path.join(savedir, 'LR')) - if not os.path.isdir(os.path.join(savedir, 'Bic')): - os.mkdir(os.path.join(savedir, 'Bic')) - - if not os.path.isdir(saveHRpath): - os.mkdir(saveHRpath) - else: - print('It will cover ' + str(saveHRpath)) - - if not os.path.isdir(saveLRpath): - os.mkdir(saveLRpath) - else: - print('It will cover ' + str(saveLRpath)) - - if not os.path.isdir(saveBicpath): - os.mkdir(saveBicpath) - else: - print('It will cover ' + str(saveBicpath)) - - filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')] - num_files = len(filepaths) - - # prepare data with augementation - for i in range(num_files): - filename = filepaths[i] - print('No.{} -- Processing {}'.format(i, filename)) - # read image - image = cv2.imread(os.path.join(sourcedir, filename)) - - width = int(np.floor(image.shape[1] / mod_scale)) - height = int(np.floor(image.shape[0] / mod_scale)) - # modcrop - if len(image.shape) == 3: - image_HR = image[0:mod_scale * height, 0:mod_scale * width, :] - else: - image_HR = image[0:mod_scale * height, 0:mod_scale * width] - # LR - image_LR = imresize_np(image_HR, 1 / up_scale, True) - # bic - image_Bic = imresize_np(image_LR, up_scale, True) - - cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) - cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) - cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) - - -if __name__ == "__main__": - generate_mod_LR_bic() diff --git a/codes/data_scripts/prepare_DIV2K_x4_dataset.sh b/codes/data_scripts/prepare_DIV2K_x4_dataset.sh deleted file mode 100644 index a53bd1f0..00000000 --- a/codes/data_scripts/prepare_DIV2K_x4_dataset.sh +++ /dev/null @@ -1,42 +0,0 @@ - - -echo "Prepare DIV2K X4 datasets..." -cd ../../datasets -mkdir DIV2K -cd DIV2K - -#### Step 1 -echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..." -# GT -FOLDER=DIV2K_train_HR -FILE=DIV2K_train_HR.zip -if [ ! -d "$FOLDER" ]; then - if [ ! -f "$FILE" ]; then - echo "Downloading $FILE..." - wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE - fi - unzip $FILE -fi -# LR -FOLDER=DIV2K_train_LR_bicubic -FILE=DIV2K_train_LR_bicubic_X4.zip -if [ ! -d "$FOLDER" ]; then - if [ ! -f "$FILE" ]; then - echo "Downloading $FILE..." - wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE - fi - unzip $FILE -fi - -#### Step 2 -echo "Step 2: Rename the LR images..." -cd ../../codes/data_scripts -python rename.py - -#### Step 4 -echo "Step 4: Crop to sub-images..." -python extract_subimages.py - -#### Step 5 -echo "Step5: Create LMDB files..." -python create_lmdb.py diff --git a/codes/data_scripts/regroup_REDS.py b/codes/data_scripts/regroup_REDS.py deleted file mode 100644 index 7c8fa928..00000000 --- a/codes/data_scripts/regroup_REDS.py +++ /dev/null @@ -1,11 +0,0 @@ -import os -import glob - -train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4' -val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4' - -# mv the val set -val_folders = glob.glob(os.path.join(val_path, '*')) -for folder in val_folders: - new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240) - os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx))) diff --git a/codes/train2.py b/codes/train2.py deleted file mode 100644 index 832d68bb..00000000 --- a/codes/train2.py +++ /dev/null @@ -1,289 +0,0 @@ -import os -import math -import argparse -import random -import logging -import shutil -from tqdm import tqdm - -import torch -from data.data_sampler import DistIterSampler - -import options.options as option -from utils import util -from data import create_dataloader, create_dataset -from models import create_model -from time import time - - -def init_dist(backend='nccl', **kwargs): - # These packages have globals that screw with Windows, so only import them if needed. - import torch.distributed as dist - import torch.multiprocessing as mp - - """initialization for distributed training""" - if mp.get_start_method(allow_none=True) != 'spawn': - mp.set_start_method('spawn') - rank = int(os.environ['RANK']) - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - dist.init_process_group(backend=backend, **kwargs) - -def main(): - #### options - parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2_fullimgref.yml') - parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) - args = parser.parse_args() - opt = option.parse(args.opt, is_train=True) - - colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode'] - if colab_mode: - # Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there. - # Each one should have a TEST file in it. - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - os.path.join(opt['remote_path'], 'training_state', "TEST")) - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - os.path.join(opt['remote_path'], 'models', "TEST")) - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - os.path.join(opt['remote_path'], 'val_images', "TEST")) - # Load the state and models needed from the remote server. - if opt['path']['resume_state']: - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state'])) - if opt['path']['pretrain_model_G']: - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G'])) - if opt['path']['pretrain_model_D']: - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D'])) - - #### distributed training settings - if args.launcher == 'none': # disabled distributed training - opt['dist'] = False - rank = -1 - print('Disabled distributed training.') - else: - opt['dist'] = True - init_dist() - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - #### loading resume state if exists - if opt['path'].get('resume_state', None): - # distributed resuming: all load into default GPU - device_id = torch.cuda.current_device() - resume_state = torch.load(opt['path']['resume_state'], - map_location=lambda storage, loc: storage.cuda(device_id)) - option.check_resume(opt, resume_state['iter']) # check resume options - else: - resume_state = None - - #### mkdir and loggers - if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) - if resume_state is None: - util.mkdir_and_rename( - opt['path']['experiments_root']) # rename experiment folder if exists - util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' - and 'pretrain_model' not in key and 'resume' not in key)) - - # config loggers. Before it, the log will not work - util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, - screen=True, tofile=True) - logger = logging.getLogger('base') - logger.info(option.dict2str(opt)) - # tensorboard logger - if opt['use_tb_logger'] and 'debug' not in opt['name']: - tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger') - version = float(torch.__version__[0:3]) - if version >= 1.1: # PyTorch 1.1 - from torch.utils.tensorboard import SummaryWriter - else: - logger.info( - 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) - from tensorboardX import SummaryWriter - tb_logger = SummaryWriter(log_dir=tb_logger_path) - else: - util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) - logger = logging.getLogger('base') - - # convert to NoneDict, which returns None for missing keys - opt = option.dict_to_nonedict(opt) - - #### random seed - seed = opt['train']['manual_seed'] - if seed is None: - seed = random.randint(1, 10000) - if rank <= 0: - logger.info('Random seed: {}'.format(seed)) - util.set_random_seed(seed) - - torch.backends.cudnn.benchmark = True - # torch.backends.cudnn.deterministic = True - - #### create train and val dataloader - dataset_ratio = 200 # enlarge the size of each epoch - for phase, dataset_opt in opt['datasets'].items(): - if phase == 'train': - train_set = create_dataset(dataset_opt) - train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) - total_iters = int(opt['train']['niter']) - total_epochs = int(math.ceil(total_iters / train_size)) - if opt['dist']: - train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) - total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) - else: - train_sampler = None - train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) - if rank <= 0: - logger.info('Number of train images: {:,d}, iters: {:,d}'.format( - len(train_set), train_size)) - logger.info('Total epochs needed: {:d} for iters {:,d}'.format( - total_epochs, total_iters)) - elif phase == 'val': - val_set = create_dataset(dataset_opt) - val_loader = create_dataloader(val_set, dataset_opt, opt, None) - if rank <= 0: - logger.info('Number of val images in [{:s}]: {:d}'.format( - dataset_opt['name'], len(val_set))) - else: - raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) - assert train_loader is not None - - #### create model - model = create_model(opt) - - #### resume training - if resume_state: - logger.info('Resuming training from epoch: {}, iter: {}.'.format( - resume_state['epoch'], resume_state['iter'])) - - start_epoch = resume_state['epoch'] - current_step = resume_state['iter'] - model.resume_training(resume_state) # handle optimizers and schedulers - else: - current_step = -1 if 'start_step' not in opt.keys() else opt['start_step'] - start_epoch = 0 - - #### training - logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) - for epoch in range(start_epoch, total_epochs + 1): - if opt['dist']: - train_sampler.set_epoch(epoch) - tq_ldr = tqdm(train_loader) - - _t = time() - _profile = False - for _, train_data in enumerate(tq_ldr): - if _profile: - print("Data fetch: %f" % (time() - _t)) - _t = time() - - current_step += 1 - if current_step > total_iters: - break - #### update learning rate - model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) - - #### training - if _profile: - print("Update LR: %f" % (time() - _t)) - _t = time() - model.feed_data(train_data) - model.optimize_parameters(current_step) - if _profile: - print("Model feed + step: %f" % (time() - _t)) - _t = time() - - #### log - if current_step % opt['logger']['print_freq'] == 0: - logs = model.get_current_log(current_step) - message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) - for v in model.get_current_learning_rate(): - message += '{:.3e},'.format(v) - message += ')] ' - for k, v in logs.items(): - if 'histogram' in k: - if rank <= 0: - tb_logger.add_histogram(k, v, current_step) - else: - message += '{:s}: {:.4e} '.format(k, v) - # tensorboard logger - if opt['use_tb_logger'] and 'debug' not in opt['name']: - if rank <= 0: - tb_logger.add_scalar(k, v, current_step) - if rank <= 0: - logger.info(message) - #### validation - if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: - if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation - model.force_restore_swapout() - val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size'] - # does not support multi-GPU validation - pbar = util.ProgressBar(len(val_loader) * val_batch_sz) - avg_psnr = 0. - avg_fea_loss = 0. - idx = 0 - colab_imgs_to_copy = [] - for val_data in val_loader: - idx += 1 - for b in range(len(val_data['LQ_path'])): - img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0] - img_dir = os.path.join(opt['path']['val_images'], img_name) - util.mkdir(img_dir) - - model.feed_data(val_data) - model.test() - - visuals = model.get_current_visuals() - if visuals is None: - continue - - sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 - #gt_img = util.tensor2img(visuals['GT'][b]) # uint8 - - # Save SR images for reference - img_base_name = '{:s}_{:d}.png'.format(img_name, current_step) - save_img_path = os.path.join(img_dir, img_base_name) - util.save_img(sr_img, save_img_path) - if colab_mode: - colab_imgs_to_copy.append(save_img_path) - - # calculate PSNR (Naw - don't do that. PSNR sucks) - #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) - #avg_psnr += util.calculate_psnr(sr_img, gt_img) - #pbar.update('Test {}'.format(img_name)) - - # calculate fea loss - avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) - - if colab_mode: - util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - colab_imgs_to_copy, - os.path.join(opt['remote_path'], 'val_images', img_base_name)) - - avg_psnr = avg_psnr / idx - avg_fea_loss = avg_fea_loss / idx - - # log - logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss)) - # tensorboard logger - if opt['use_tb_logger'] and 'debug' not in opt['name']: - #tb_logger.add_scalar('val_psnr', avg_psnr, current_step) - tb_logger.add_scalar('val_fea', avg_fea_loss, current_step) - - #### save models and training states - if current_step % opt['logger']['save_checkpoint_freq'] == 0: - if rank <= 0: - logger.info('Saving models and training states.') - model.save(current_step) - model.save_training_state(epoch, current_step) - - if rank <= 0: - logger.info('Saving the final model.') - model.save('latest') - logger.info('End of training.') - tb_logger.close() - - -if __name__ == '__main__': - main()