Add extract_subimages_with_ref_lmdb for generating lmdb with reference images

This commit is contained in:
James Betker 2020-09-04 15:30:34 -06:00
parent 696242064c
commit 64a24503f6
8 changed files with 236 additions and 555 deletions

View File

@ -77,7 +77,6 @@ def main():
else:
raise ValueError('Wrong mode.')
def extract_single(opt, split_img=False):
input_folder = opt['input_folder']
save_folder = opt['save_folder']

View File

@ -0,0 +1,236 @@
"""A multi-thread tool to crop large images to sub-images for faster IO."""
import os
import os.path as osp
import numpy as np
import cv2
from PIL import Image
import data.util as data_util # noqa: E402
import lmdb
import pyarrow
import torch.utils.data as data
from tqdm import tqdm
def main():
mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs)
split_img = False
opt = {}
opt['n_thread'] = 12
opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
if mode == 'single':
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images'
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\lmdb_with_ref'
opt['crop_sz'] = 512 # the size of each sub-image
opt['step'] = 128 # step of the sliding crop window
opt['thres_sz'] = 128 # size threshold
opt['resize_final_img'] = .5
opt['only_resize'] = False
extract_single(opt, split_img)
elif mode == 'pair':
GT_folder = '../../datasets/div2k/DIV2K_train_HR'
LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4'
save_GT_folder = '../../datasets/div2k/DIV2K800_sub'
save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4'
scale_ratio = 4
crop_sz = 480 # the size of each sub-image (GT)
step = 240 # step of the sliding crop window (GT)
thres_sz = 48 # size threshold
########################################################################
# check that all the GT and LR images have correct scale ratio
img_GT_list = data_util._get_paths_from_images(GT_folder)
img_LR_list = data_util._get_paths_from_images(LR_folder)
assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
for path_GT, path_LR in zip(img_GT_list, img_LR_list):
img_GT = Image.open(path_GT)
img_LR = Image.open(path_LR)
w_GT, h_GT = img_GT.size
w_LR, h_LR = img_LR.size
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
w_GT, scale_ratio, w_LR, path_GT)
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
w_GT, scale_ratio, w_LR, path_GT)
# check crop size, step and threshold size
assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
scale_ratio)
assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
scale_ratio)
print('process GT...')
opt['input_folder'] = GT_folder
opt['save_folder'] = save_GT_folder
opt['crop_sz'] = crop_sz
opt['step'] = step
opt['thres_sz'] = thres_sz
extract_single(opt)
print('process LR...')
opt['input_folder'] = LR_folder
opt['save_folder'] = save_LR_folder
opt['crop_sz'] = crop_sz // scale_ratio
opt['step'] = step // scale_ratio
opt['thres_sz'] = thres_sz // scale_ratio
extract_single(opt)
assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
data_util._get_paths_from_images(
save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
else:
raise ValueError('Wrong mode.')
class LmdbWriter:
def __init__(self, lmdb_path, max_mem_size=30*1024*1024*1024, write_freq=5000):
self.db = lmdb.open(lmdb_path, subdir=True,
map_size=max_mem_size, readonly=False,
meminit=False, map_async=True)
self.txn = self.db.begin(write=True)
self.ref_id = 0
self.tile_ids = {}
self.writes = 0
self.write_freq = write_freq
self.keys = []
# Writes the given reference image to the db and returns its ID.
def write_reference_image(self, ref_img):
id = self.ref_id
self.ref_id += 1
self.write_image(id, ref_img[0], ref_img[1])
return id
# Writes a tile image to the db given a reference image and returns its ID.
def write_tile_image(self, ref_id, tile_image):
next_tile_id = 0 if ref_id not in self.tile_ids.keys() else self.tile_ids[ref_id]
self.tile_ids[ref_id] = next_tile_id+1
full_id = "%i_%i" % (ref_id, next_tile_id)
self.write_image(full_id, tile_image[0], tile_image[1])
self.keys.append(full_id)
return full_id
# Writes an image directly to the db with the given reference image and center point.
def write_image(self, id, img, center_point):
self.txn.put(u'{}'.format(id).encode('ascii'), pyarrow.serialize(img).to_buffer(), pyarrow.serialize(center_point).to_buffer())
self.writes += 1
if self.writes % self.write_freq == 0:
self.txn.commit()
self.txn = self.db.begin(write=True)
def close(self):
self.txn.commit()
with self.db.begin(write=True) as txn:
txn.put(b'__keys__', pyarrow.serialize(self.keys).to_buffer())
txn.put(b'__len__', pyarrow.serialize(len(self.keys)).to_buffer())
self.db.sync()
self.db.close()
class TiledDataset(data.Dataset):
def __init__(self, opt, split_mode=False):
self.split_mode = split_mode
self.opt = opt
input_folder = opt['input_folder']
self.images = data_util._get_paths_from_images(input_folder)
def __getitem__(self, index):
if self.split_mode:
return self.get(index, True, True).extend(self.get(index, True, False))
else:
return self.get(index, False, False)
def get(self, index, split_mode, left_img):
path = self.images[index]
crop_sz = self.opt['crop_sz']
step = self.opt['step']
thres_sz = self.opt['thres_sz']
only_resize = self.opt['only_resize']
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
# We must convert the image into a square. Crop the image so that only the center is left, since this is often
# the most salient part of the image.
h, w, c = img.shape
dim = min(h, w)
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
h, w, c = img.shape
# Uncomment to filter any image that doesnt meet a threshold size.
if min(h,w) < 1024:
return
left = 0
right = w
if split_mode:
if left_img:
left = 0
right = int(w/2)
else:
left = int(w/2)
right = w
w = int(w/2)
img = img[:, left:right]
h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz:
h_space = np.append(h_space, h - crop_sz)
w_space = np.arange(0, w - crop_sz + 1, step)
if w - (w_space[-1] + crop_sz) > thres_sz:
w_space = np.append(w_space, w - crop_sz)
dsize = None
if only_resize:
dsize = (crop_sz, crop_sz)
if h < w:
h_space = [0]
w_space = [(w - h) // 2]
crop_sz = h
else:
h_space = [(h - w) // 2]
w_space = [0]
crop_sz = w
index = 0
resize_factor = self.opt['resize_final_img'] if 'resize_final_img' in self.opt.keys() else 1
dsize = (int(crop_sz * resize_factor), int(crop_sz * resize_factor))
# Reference image should always be first.
results = [(cv2.resize(img, dsize, interpolation=cv2.INTER_AREA), (-1,-1))]
for x in h_space:
for y in w_space:
index += 1
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
center_point = (x + crop_sz // 2, y + crop_sz // 2)
crop_img = np.ascontiguousarray(crop_img)
if 'resize_final_img' in self.opt.keys():
# Resize too.
resize_factor = self.opt['resize_final_img']
center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor))
crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA)
success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
assert success
results.append((buffer, center_point))
return results
def __len__(self):
return len(self.images)
def identity(x):
return x
def extract_single(opt, split_img=False):
save_folder = opt['save_folder']
if not osp.exists(save_folder):
os.makedirs(save_folder)
print('mkdir [{:s}] ...'.format(save_folder))
lmdb = LmdbWriter(save_folder)
dataset = TiledDataset(opt, split_img)
dataloader = data.DataLoader(dataset, num_workers=opt['n_thread'], collate_fn=identity)
tq = tqdm(dataloader)
for imgs in tq:
if imgs is None or len(imgs) <= 1:
continue
ref_id = lmdb.write_reference_image(imgs[0])
for tile in imgs[1:]:
lmdb.write_tile_image(ref_id, tile)
lmdb.close()
if __name__ == '__main__':
main()

View File

@ -1,49 +0,0 @@
function generate_LR_Vimeo90K()
%% matlab code to genetate bicubic-downsampled for Vimeo90K dataset
up_scale = 4;
mod_scale = 4;
idx = 0;
filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
for i = 1 : length(filepaths)
[~,imname,ext] = fileparts(filepaths(i).name);
folder_path = filepaths(i).folder;
save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
if ~exist(save_LR_folder, 'dir')
mkdir(save_LR_folder);
end
if isempty(imname)
disp('Ignore . folder.');
elseif strcmp(imname, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_rlt = sprintf('%d\t%s.\n', idx, imname);
fprintf(str_rlt);
% read image
img = imread(fullfile(folder_path, [imname, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
% LR
im_LR = imresize(img, 1/up_scale, 'bicubic');
if exist('save_LR_folder', 'var')
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end

View File

@ -1,82 +0,0 @@
function generate_mod_LR_bic()
%% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images.
%% set parameters
% comment the unnecessary line
input_folder = '../../datasets/DIV2K/DIV2K800';
% save_mod_folder = '';
save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
% save_bic_folder = '';
up_scale = 4;
mod_scale = 4;
if exist('save_mod_folder', 'var')
if exist(save_mod_folder, 'dir')
disp(['It will cover ', save_mod_folder]);
else
mkdir(save_mod_folder);
end
end
if exist('save_LR_folder', 'var')
if exist(save_LR_folder, 'dir')
disp(['It will cover ', save_LR_folder]);
else
mkdir(save_LR_folder);
end
end
if exist('save_bic_folder', 'var')
if exist(save_bic_folder, 'dir')
disp(['It will cover ', save_bic_folder]);
else
mkdir(save_bic_folder);
end
end
idx = 0;
filepaths = dir(fullfile(input_folder,'*.*'));
for i = 1 : length(filepaths)
[paths,imname,ext] = fileparts(filepaths(i).name);
if isempty(imname)
disp('Ignore . folder.');
elseif strcmp(imname, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_rlt = sprintf('%d\t%s.\n', idx, imname);
fprintf(str_rlt);
% read image
img = imread(fullfile(input_folder, [imname, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
if exist('save_mod_folder', 'var')
imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
end
% LR
im_LR = imresize(img, 1/up_scale, 'bicubic');
if exist('save_LR_folder', 'var')
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
end
% Bicubic
if exist('save_bic_folder', 'var')
im_B = imresize(im_LR, up_scale, 'bicubic');
imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end

View File

@ -1,81 +0,0 @@
import os
import sys
import cv2
import numpy as np
try:
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.util import imresize_np
except ImportError:
pass
def generate_mod_LR_bic():
# set parameters
up_scale = 4
mod_scale = 4
# set data dir
sourcedir = '/data/datasets/img'
savedir = '/data/datasets/mod'
saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
if not os.path.isdir(sourcedir):
print('Error: No source data found')
exit(0)
if not os.path.isdir(savedir):
os.mkdir(savedir)
if not os.path.isdir(os.path.join(savedir, 'HR')):
os.mkdir(os.path.join(savedir, 'HR'))
if not os.path.isdir(os.path.join(savedir, 'LR')):
os.mkdir(os.path.join(savedir, 'LR'))
if not os.path.isdir(os.path.join(savedir, 'Bic')):
os.mkdir(os.path.join(savedir, 'Bic'))
if not os.path.isdir(saveHRpath):
os.mkdir(saveHRpath)
else:
print('It will cover ' + str(saveHRpath))
if not os.path.isdir(saveLRpath):
os.mkdir(saveLRpath)
else:
print('It will cover ' + str(saveLRpath))
if not os.path.isdir(saveBicpath):
os.mkdir(saveBicpath)
else:
print('It will cover ' + str(saveBicpath))
filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
num_files = len(filepaths)
# prepare data with augementation
for i in range(num_files):
filename = filepaths[i]
print('No.{} -- Processing {}'.format(i, filename))
# read image
image = cv2.imread(os.path.join(sourcedir, filename))
width = int(np.floor(image.shape[1] / mod_scale))
height = int(np.floor(image.shape[0] / mod_scale))
# modcrop
if len(image.shape) == 3:
image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
else:
image_HR = image[0:mod_scale * height, 0:mod_scale * width]
# LR
image_LR = imresize_np(image_HR, 1 / up_scale, True)
# bic
image_Bic = imresize_np(image_LR, up_scale, True)
cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
if __name__ == "__main__":
generate_mod_LR_bic()

View File

@ -1,42 +0,0 @@
echo "Prepare DIV2K X4 datasets..."
cd ../../datasets
mkdir DIV2K
cd DIV2K
#### Step 1
echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
# GT
FOLDER=DIV2K_train_HR
FILE=DIV2K_train_HR.zip
if [ ! -d "$FOLDER" ]; then
if [ ! -f "$FILE" ]; then
echo "Downloading $FILE..."
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
fi
unzip $FILE
fi
# LR
FOLDER=DIV2K_train_LR_bicubic
FILE=DIV2K_train_LR_bicubic_X4.zip
if [ ! -d "$FOLDER" ]; then
if [ ! -f "$FILE" ]; then
echo "Downloading $FILE..."
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
fi
unzip $FILE
fi
#### Step 2
echo "Step 2: Rename the LR images..."
cd ../../codes/data_scripts
python rename.py
#### Step 4
echo "Step 4: Crop to sub-images..."
python extract_subimages.py
#### Step 5
echo "Step5: Create LMDB files..."
python create_lmdb.py

View File

@ -1,11 +0,0 @@
import os
import glob
train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4'
val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4'
# mv the val set
val_folders = glob.glob(os.path.join(val_path, '*'))
for folder in val_folders:
new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240)
os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx)))

View File

@ -1,289 +0,0 @@
import os
import math
import argparse
import random
import logging
import shutil
from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from models import create_model
from time import time
def init_dist(backend='nccl', **kwargs):
# These packages have globals that screw with Windows, so only import them if needed.
import torch.distributed as dist
import torch.multiprocessing as mp
"""initialization for distributed training"""
if mp.get_start_method(allow_none=True) != 'spawn':
mp.set_start_method('spawn')
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2_fullimgref.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode']
if colab_mode:
# Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there.
# Each one should have a TEST file in it.
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
os.path.join(opt['remote_path'], 'training_state', "TEST"))
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
os.path.join(opt['remote_path'], 'models', "TEST"))
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
os.path.join(opt['remote_path'], 'val_images', "TEST"))
# Load the state and models needed from the remote server.
if opt['path']['resume_state']:
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state']))
if opt['path']['pretrain_model_G']:
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G']))
if opt['path']['pretrain_model_D']:
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D']))
#### distributed training settings
if args.launcher == 'none': # disabled distributed training
opt['dist'] = False
rank = -1
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist()
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
#### loading resume state if exists
if opt['path'].get('resume_state', None):
# distributed resuming: all load into default GPU
device_id = torch.cuda.current_device()
resume_state = torch.load(opt['path']['resume_state'],
map_location=lambda storage, loc: storage.cuda(device_id))
option.check_resume(opt, resume_state['iter']) # check resume options
else:
resume_state = None
#### mkdir and loggers
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
if resume_state is None:
util.mkdir_and_rename(
opt['path']['experiments_root']) # rename experiment folder if exists
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
and 'pretrain_model' not in key and 'resume' not in key))
# config loggers. Before it, the log will not work
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
version = float(torch.__version__[0:3])
if version >= 1.1: # PyTorch 1.1
from torch.utils.tensorboard import SummaryWriter
else:
logger.info(
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
from tensorboardX import SummaryWriter
tb_logger = SummaryWriter(log_dir=tb_logger_path)
else:
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
logger = logging.getLogger('base')
# convert to NoneDict, which returns None for missing keys
opt = option.dict_to_nonedict(opt)
#### random seed
seed = opt['train']['manual_seed']
if seed is None:
seed = random.randint(1, 10000)
if rank <= 0:
logger.info('Random seed: {}'.format(seed))
util.set_random_seed(seed)
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
#### create train and val dataloader
dataset_ratio = 200 # enlarge the size of each epoch
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
train_set = create_dataset(dataset_opt)
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
total_epochs = int(math.ceil(total_iters / train_size))
if opt['dist']:
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
else:
train_sampler = None
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
if rank <= 0:
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
len(train_set), train_size))
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
total_epochs, total_iters))
elif phase == 'val':
val_set = create_dataset(dataset_opt)
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
if rank <= 0:
logger.info('Number of val images in [{:s}]: {:d}'.format(
dataset_opt['name'], len(val_set)))
else:
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
assert train_loader is not None
#### create model
model = create_model(opt)
#### resume training
if resume_state:
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
resume_state['epoch'], resume_state['iter']))
start_epoch = resume_state['epoch']
current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers
else:
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
start_epoch = 0
#### training
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
for epoch in range(start_epoch, total_epochs + 1):
if opt['dist']:
train_sampler.set_epoch(epoch)
tq_ldr = tqdm(train_loader)
_t = time()
_profile = False
for _, train_data in enumerate(tq_ldr):
if _profile:
print("Data fetch: %f" % (time() - _t))
_t = time()
current_step += 1
if current_step > total_iters:
break
#### update learning rate
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
#### training
if _profile:
print("Update LR: %f" % (time() - _t))
_t = time()
model.feed_data(train_data)
model.optimize_parameters(current_step)
if _profile:
print("Model feed + step: %f" % (time() - _t))
_t = time()
#### log
if current_step % opt['logger']['print_freq'] == 0:
logs = model.get_current_log(current_step)
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
for v in model.get_current_learning_rate():
message += '{:.3e},'.format(v)
message += ')] '
for k, v in logs.items():
if 'histogram' in k:
if rank <= 0:
tb_logger.add_histogram(k, v, current_step)
else:
message += '{:s}: {:.4e} '.format(k, v)
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
if rank <= 0:
tb_logger.add_scalar(k, v, current_step)
if rank <= 0:
logger.info(message)
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
model.force_restore_swapout()
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
# does not support multi-GPU validation
pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
avg_psnr = 0.
avg_fea_loss = 0.
idx = 0
colab_imgs_to_copy = []
for val_data in val_loader:
idx += 1
for b in range(len(val_data['LQ_path'])):
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)
model.feed_data(val_data)
model.test()
visuals = model.get_current_visuals()
if visuals is None:
continue
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
#gt_img = util.tensor2img(visuals['GT'][b]) # uint8
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
save_img_path = os.path.join(img_dir, img_base_name)
util.save_img(sr_img, save_img_path)
if colab_mode:
colab_imgs_to_copy.append(save_img_path)
# calculate PSNR (Naw - don't do that. PSNR sucks)
#sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
#avg_psnr += util.calculate_psnr(sr_img, gt_img)
#pbar.update('Test {}'.format(img_name))
# calculate fea loss
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
if colab_mode:
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
colab_imgs_to_copy,
os.path.join(opt['remote_path'], 'val_images', img_base_name))
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
# log
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
#tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
#### save models and training states
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
if rank <= 0:
logger.info('Saving models and training states.')
model.save(current_step)
model.save_training_state(epoch, current_step)
if rank <= 0:
logger.info('Saving the final model.')
model.save('latest')
logger.info('End of training.')
tb_logger.close()
if __name__ == '__main__':
main()