forked from mrq/DL-Art-School
Add extract_subimages_with_ref_lmdb for generating lmdb with reference images
This commit is contained in:
parent
696242064c
commit
64a24503f6
|
@ -77,7 +77,6 @@ def main():
|
|||
else:
|
||||
raise ValueError('Wrong mode.')
|
||||
|
||||
|
||||
def extract_single(opt, split_img=False):
|
||||
input_folder = opt['input_folder']
|
||||
save_folder = opt['save_folder']
|
||||
|
|
236
codes/data_scripts/extract_subimages_with_ref_lmdb.py
Normal file
236
codes/data_scripts/extract_subimages_with_ref_lmdb.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
"""A multi-thread tool to crop large images to sub-images for faster IO."""
|
||||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import data.util as data_util # noqa: E402
|
||||
import lmdb
|
||||
import pyarrow
|
||||
import torch.utils.data as data
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs)
|
||||
split_img = False
|
||||
opt = {}
|
||||
opt['n_thread'] = 12
|
||||
opt['compression_level'] = 90 # JPEG compression quality rating.
|
||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||
if mode == 'single':
|
||||
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images'
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\lmdb_with_ref'
|
||||
opt['crop_sz'] = 512 # the size of each sub-image
|
||||
opt['step'] = 128 # step of the sliding crop window
|
||||
opt['thres_sz'] = 128 # size threshold
|
||||
opt['resize_final_img'] = .5
|
||||
opt['only_resize'] = False
|
||||
extract_single(opt, split_img)
|
||||
elif mode == 'pair':
|
||||
GT_folder = '../../datasets/div2k/DIV2K_train_HR'
|
||||
LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4'
|
||||
save_GT_folder = '../../datasets/div2k/DIV2K800_sub'
|
||||
save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4'
|
||||
scale_ratio = 4
|
||||
crop_sz = 480 # the size of each sub-image (GT)
|
||||
step = 240 # step of the sliding crop window (GT)
|
||||
thres_sz = 48 # size threshold
|
||||
########################################################################
|
||||
# check that all the GT and LR images have correct scale ratio
|
||||
img_GT_list = data_util._get_paths_from_images(GT_folder)
|
||||
img_LR_list = data_util._get_paths_from_images(LR_folder)
|
||||
assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
|
||||
for path_GT, path_LR in zip(img_GT_list, img_LR_list):
|
||||
img_GT = Image.open(path_GT)
|
||||
img_LR = Image.open(path_LR)
|
||||
w_GT, h_GT = img_GT.size
|
||||
w_LR, h_LR = img_LR.size
|
||||
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
|
||||
w_GT, scale_ratio, w_LR, path_GT)
|
||||
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
|
||||
w_GT, scale_ratio, w_LR, path_GT)
|
||||
# check crop size, step and threshold size
|
||||
assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
|
||||
scale_ratio)
|
||||
assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
|
||||
assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
|
||||
scale_ratio)
|
||||
print('process GT...')
|
||||
opt['input_folder'] = GT_folder
|
||||
opt['save_folder'] = save_GT_folder
|
||||
opt['crop_sz'] = crop_sz
|
||||
opt['step'] = step
|
||||
opt['thres_sz'] = thres_sz
|
||||
extract_single(opt)
|
||||
print('process LR...')
|
||||
opt['input_folder'] = LR_folder
|
||||
opt['save_folder'] = save_LR_folder
|
||||
opt['crop_sz'] = crop_sz // scale_ratio
|
||||
opt['step'] = step // scale_ratio
|
||||
opt['thres_sz'] = thres_sz // scale_ratio
|
||||
extract_single(opt)
|
||||
assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
|
||||
data_util._get_paths_from_images(
|
||||
save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
|
||||
else:
|
||||
raise ValueError('Wrong mode.')
|
||||
|
||||
|
||||
class LmdbWriter:
|
||||
def __init__(self, lmdb_path, max_mem_size=30*1024*1024*1024, write_freq=5000):
|
||||
self.db = lmdb.open(lmdb_path, subdir=True,
|
||||
map_size=max_mem_size, readonly=False,
|
||||
meminit=False, map_async=True)
|
||||
self.txn = self.db.begin(write=True)
|
||||
self.ref_id = 0
|
||||
self.tile_ids = {}
|
||||
self.writes = 0
|
||||
self.write_freq = write_freq
|
||||
self.keys = []
|
||||
|
||||
# Writes the given reference image to the db and returns its ID.
|
||||
def write_reference_image(self, ref_img):
|
||||
id = self.ref_id
|
||||
self.ref_id += 1
|
||||
self.write_image(id, ref_img[0], ref_img[1])
|
||||
return id
|
||||
|
||||
# Writes a tile image to the db given a reference image and returns its ID.
|
||||
def write_tile_image(self, ref_id, tile_image):
|
||||
next_tile_id = 0 if ref_id not in self.tile_ids.keys() else self.tile_ids[ref_id]
|
||||
self.tile_ids[ref_id] = next_tile_id+1
|
||||
full_id = "%i_%i" % (ref_id, next_tile_id)
|
||||
self.write_image(full_id, tile_image[0], tile_image[1])
|
||||
self.keys.append(full_id)
|
||||
return full_id
|
||||
|
||||
# Writes an image directly to the db with the given reference image and center point.
|
||||
def write_image(self, id, img, center_point):
|
||||
self.txn.put(u'{}'.format(id).encode('ascii'), pyarrow.serialize(img).to_buffer(), pyarrow.serialize(center_point).to_buffer())
|
||||
self.writes += 1
|
||||
if self.writes % self.write_freq == 0:
|
||||
self.txn.commit()
|
||||
self.txn = self.db.begin(write=True)
|
||||
|
||||
def close(self):
|
||||
self.txn.commit()
|
||||
with self.db.begin(write=True) as txn:
|
||||
txn.put(b'__keys__', pyarrow.serialize(self.keys).to_buffer())
|
||||
txn.put(b'__len__', pyarrow.serialize(len(self.keys)).to_buffer())
|
||||
self.db.sync()
|
||||
self.db.close()
|
||||
|
||||
|
||||
class TiledDataset(data.Dataset):
|
||||
def __init__(self, opt, split_mode=False):
|
||||
self.split_mode = split_mode
|
||||
self.opt = opt
|
||||
input_folder = opt['input_folder']
|
||||
self.images = data_util._get_paths_from_images(input_folder)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.split_mode:
|
||||
return self.get(index, True, True).extend(self.get(index, True, False))
|
||||
else:
|
||||
return self.get(index, False, False)
|
||||
|
||||
def get(self, index, split_mode, left_img):
|
||||
path = self.images[index]
|
||||
crop_sz = self.opt['crop_sz']
|
||||
step = self.opt['step']
|
||||
thres_sz = self.opt['thres_sz']
|
||||
only_resize = self.opt['only_resize']
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
# We must convert the image into a square. Crop the image so that only the center is left, since this is often
|
||||
# the most salient part of the image.
|
||||
h, w, c = img.shape
|
||||
dim = min(h, w)
|
||||
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
||||
|
||||
h, w, c = img.shape
|
||||
# Uncomment to filter any image that doesnt meet a threshold size.
|
||||
if min(h,w) < 1024:
|
||||
return
|
||||
left = 0
|
||||
right = w
|
||||
if split_mode:
|
||||
if left_img:
|
||||
left = 0
|
||||
right = int(w/2)
|
||||
else:
|
||||
left = int(w/2)
|
||||
right = w
|
||||
w = int(w/2)
|
||||
img = img[:, left:right]
|
||||
|
||||
h_space = np.arange(0, h - crop_sz + 1, step)
|
||||
if h - (h_space[-1] + crop_sz) > thres_sz:
|
||||
h_space = np.append(h_space, h - crop_sz)
|
||||
w_space = np.arange(0, w - crop_sz + 1, step)
|
||||
if w - (w_space[-1] + crop_sz) > thres_sz:
|
||||
w_space = np.append(w_space, w - crop_sz)
|
||||
|
||||
dsize = None
|
||||
if only_resize:
|
||||
dsize = (crop_sz, crop_sz)
|
||||
if h < w:
|
||||
h_space = [0]
|
||||
w_space = [(w - h) // 2]
|
||||
crop_sz = h
|
||||
else:
|
||||
h_space = [(h - w) // 2]
|
||||
w_space = [0]
|
||||
crop_sz = w
|
||||
|
||||
index = 0
|
||||
resize_factor = self.opt['resize_final_img'] if 'resize_final_img' in self.opt.keys() else 1
|
||||
dsize = (int(crop_sz * resize_factor), int(crop_sz * resize_factor))
|
||||
# Reference image should always be first.
|
||||
results = [(cv2.resize(img, dsize, interpolation=cv2.INTER_AREA), (-1,-1))]
|
||||
for x in h_space:
|
||||
for y in w_space:
|
||||
index += 1
|
||||
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
|
||||
center_point = (x + crop_sz // 2, y + crop_sz // 2)
|
||||
crop_img = np.ascontiguousarray(crop_img)
|
||||
if 'resize_final_img' in self.opt.keys():
|
||||
# Resize too.
|
||||
resize_factor = self.opt['resize_final_img']
|
||||
center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor))
|
||||
crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA)
|
||||
success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
|
||||
assert success
|
||||
results.append((buffer, center_point))
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
def extract_single(opt, split_img=False):
|
||||
save_folder = opt['save_folder']
|
||||
if not osp.exists(save_folder):
|
||||
os.makedirs(save_folder)
|
||||
print('mkdir [{:s}] ...'.format(save_folder))
|
||||
lmdb = LmdbWriter(save_folder)
|
||||
|
||||
dataset = TiledDataset(opt, split_img)
|
||||
dataloader = data.DataLoader(dataset, num_workers=opt['n_thread'], collate_fn=identity)
|
||||
tq = tqdm(dataloader)
|
||||
for imgs in tq:
|
||||
if imgs is None or len(imgs) <= 1:
|
||||
continue
|
||||
ref_id = lmdb.write_reference_image(imgs[0])
|
||||
for tile in imgs[1:]:
|
||||
lmdb.write_tile_image(ref_id, tile)
|
||||
lmdb.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,49 +0,0 @@
|
|||
function generate_LR_Vimeo90K()
|
||||
%% matlab code to genetate bicubic-downsampled for Vimeo90K dataset
|
||||
|
||||
up_scale = 4;
|
||||
mod_scale = 4;
|
||||
idx = 0;
|
||||
filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
|
||||
for i = 1 : length(filepaths)
|
||||
[~,imname,ext] = fileparts(filepaths(i).name);
|
||||
folder_path = filepaths(i).folder;
|
||||
save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
|
||||
if ~exist(save_LR_folder, 'dir')
|
||||
mkdir(save_LR_folder);
|
||||
end
|
||||
if isempty(imname)
|
||||
disp('Ignore . folder.');
|
||||
elseif strcmp(imname, '.')
|
||||
disp('Ignore .. folder.');
|
||||
else
|
||||
idx = idx + 1;
|
||||
str_rlt = sprintf('%d\t%s.\n', idx, imname);
|
||||
fprintf(str_rlt);
|
||||
% read image
|
||||
img = imread(fullfile(folder_path, [imname, ext]));
|
||||
img = im2double(img);
|
||||
% modcrop
|
||||
img = modcrop(img, mod_scale);
|
||||
% LR
|
||||
im_LR = imresize(img, 1/up_scale, 'bicubic');
|
||||
if exist('save_LR_folder', 'var')
|
||||
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
%% modcrop
|
||||
function img = modcrop(img, modulo)
|
||||
if size(img,3) == 1
|
||||
sz = size(img);
|
||||
sz = sz - mod(sz, modulo);
|
||||
img = img(1:sz(1), 1:sz(2));
|
||||
else
|
||||
tmpsz = size(img);
|
||||
sz = tmpsz(1:2);
|
||||
sz = sz - mod(sz, modulo);
|
||||
img = img(1:sz(1), 1:sz(2),:);
|
||||
end
|
||||
end
|
|
@ -1,82 +0,0 @@
|
|||
function generate_mod_LR_bic()
|
||||
%% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images.
|
||||
|
||||
%% set parameters
|
||||
% comment the unnecessary line
|
||||
input_folder = '../../datasets/DIV2K/DIV2K800';
|
||||
% save_mod_folder = '';
|
||||
save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
|
||||
% save_bic_folder = '';
|
||||
|
||||
up_scale = 4;
|
||||
mod_scale = 4;
|
||||
|
||||
if exist('save_mod_folder', 'var')
|
||||
if exist(save_mod_folder, 'dir')
|
||||
disp(['It will cover ', save_mod_folder]);
|
||||
else
|
||||
mkdir(save_mod_folder);
|
||||
end
|
||||
end
|
||||
if exist('save_LR_folder', 'var')
|
||||
if exist(save_LR_folder, 'dir')
|
||||
disp(['It will cover ', save_LR_folder]);
|
||||
else
|
||||
mkdir(save_LR_folder);
|
||||
end
|
||||
end
|
||||
if exist('save_bic_folder', 'var')
|
||||
if exist(save_bic_folder, 'dir')
|
||||
disp(['It will cover ', save_bic_folder]);
|
||||
else
|
||||
mkdir(save_bic_folder);
|
||||
end
|
||||
end
|
||||
|
||||
idx = 0;
|
||||
filepaths = dir(fullfile(input_folder,'*.*'));
|
||||
for i = 1 : length(filepaths)
|
||||
[paths,imname,ext] = fileparts(filepaths(i).name);
|
||||
if isempty(imname)
|
||||
disp('Ignore . folder.');
|
||||
elseif strcmp(imname, '.')
|
||||
disp('Ignore .. folder.');
|
||||
else
|
||||
idx = idx + 1;
|
||||
str_rlt = sprintf('%d\t%s.\n', idx, imname);
|
||||
fprintf(str_rlt);
|
||||
% read image
|
||||
img = imread(fullfile(input_folder, [imname, ext]));
|
||||
img = im2double(img);
|
||||
% modcrop
|
||||
img = modcrop(img, mod_scale);
|
||||
if exist('save_mod_folder', 'var')
|
||||
imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
|
||||
end
|
||||
% LR
|
||||
im_LR = imresize(img, 1/up_scale, 'bicubic');
|
||||
if exist('save_LR_folder', 'var')
|
||||
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
|
||||
end
|
||||
% Bicubic
|
||||
if exist('save_bic_folder', 'var')
|
||||
im_B = imresize(im_LR, up_scale, 'bicubic');
|
||||
imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
%% modcrop
|
||||
function img = modcrop(img, modulo)
|
||||
if size(img,3) == 1
|
||||
sz = size(img);
|
||||
sz = sz - mod(sz, modulo);
|
||||
img = img(1:sz(1), 1:sz(2));
|
||||
else
|
||||
tmpsz = size(img);
|
||||
sz = tmpsz(1:2);
|
||||
sz = sz - mod(sz, modulo);
|
||||
img = img(1:sz(1), 1:sz(2),:);
|
||||
end
|
||||
end
|
|
@ -1,81 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from data.util import imresize_np
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def generate_mod_LR_bic():
|
||||
# set parameters
|
||||
up_scale = 4
|
||||
mod_scale = 4
|
||||
# set data dir
|
||||
sourcedir = '/data/datasets/img'
|
||||
savedir = '/data/datasets/mod'
|
||||
|
||||
saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
|
||||
saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
|
||||
saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
|
||||
|
||||
if not os.path.isdir(sourcedir):
|
||||
print('Error: No source data found')
|
||||
exit(0)
|
||||
if not os.path.isdir(savedir):
|
||||
os.mkdir(savedir)
|
||||
|
||||
if not os.path.isdir(os.path.join(savedir, 'HR')):
|
||||
os.mkdir(os.path.join(savedir, 'HR'))
|
||||
if not os.path.isdir(os.path.join(savedir, 'LR')):
|
||||
os.mkdir(os.path.join(savedir, 'LR'))
|
||||
if not os.path.isdir(os.path.join(savedir, 'Bic')):
|
||||
os.mkdir(os.path.join(savedir, 'Bic'))
|
||||
|
||||
if not os.path.isdir(saveHRpath):
|
||||
os.mkdir(saveHRpath)
|
||||
else:
|
||||
print('It will cover ' + str(saveHRpath))
|
||||
|
||||
if not os.path.isdir(saveLRpath):
|
||||
os.mkdir(saveLRpath)
|
||||
else:
|
||||
print('It will cover ' + str(saveLRpath))
|
||||
|
||||
if not os.path.isdir(saveBicpath):
|
||||
os.mkdir(saveBicpath)
|
||||
else:
|
||||
print('It will cover ' + str(saveBicpath))
|
||||
|
||||
filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
|
||||
num_files = len(filepaths)
|
||||
|
||||
# prepare data with augementation
|
||||
for i in range(num_files):
|
||||
filename = filepaths[i]
|
||||
print('No.{} -- Processing {}'.format(i, filename))
|
||||
# read image
|
||||
image = cv2.imread(os.path.join(sourcedir, filename))
|
||||
|
||||
width = int(np.floor(image.shape[1] / mod_scale))
|
||||
height = int(np.floor(image.shape[0] / mod_scale))
|
||||
# modcrop
|
||||
if len(image.shape) == 3:
|
||||
image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
|
||||
else:
|
||||
image_HR = image[0:mod_scale * height, 0:mod_scale * width]
|
||||
# LR
|
||||
image_LR = imresize_np(image_HR, 1 / up_scale, True)
|
||||
# bic
|
||||
image_Bic = imresize_np(image_LR, up_scale, True)
|
||||
|
||||
cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
|
||||
cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
|
||||
cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_mod_LR_bic()
|
|
@ -1,42 +0,0 @@
|
|||
|
||||
|
||||
echo "Prepare DIV2K X4 datasets..."
|
||||
cd ../../datasets
|
||||
mkdir DIV2K
|
||||
cd DIV2K
|
||||
|
||||
#### Step 1
|
||||
echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
|
||||
# GT
|
||||
FOLDER=DIV2K_train_HR
|
||||
FILE=DIV2K_train_HR.zip
|
||||
if [ ! -d "$FOLDER" ]; then
|
||||
if [ ! -f "$FILE" ]; then
|
||||
echo "Downloading $FILE..."
|
||||
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
|
||||
fi
|
||||
unzip $FILE
|
||||
fi
|
||||
# LR
|
||||
FOLDER=DIV2K_train_LR_bicubic
|
||||
FILE=DIV2K_train_LR_bicubic_X4.zip
|
||||
if [ ! -d "$FOLDER" ]; then
|
||||
if [ ! -f "$FILE" ]; then
|
||||
echo "Downloading $FILE..."
|
||||
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
|
||||
fi
|
||||
unzip $FILE
|
||||
fi
|
||||
|
||||
#### Step 2
|
||||
echo "Step 2: Rename the LR images..."
|
||||
cd ../../codes/data_scripts
|
||||
python rename.py
|
||||
|
||||
#### Step 4
|
||||
echo "Step 4: Crop to sub-images..."
|
||||
python extract_subimages.py
|
||||
|
||||
#### Step 5
|
||||
echo "Step5: Create LMDB files..."
|
||||
python create_lmdb.py
|
|
@ -1,11 +0,0 @@
|
|||
import os
|
||||
import glob
|
||||
|
||||
train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4'
|
||||
val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4'
|
||||
|
||||
# mv the val set
|
||||
val_folders = glob.glob(os.path.join(val_path, '*'))
|
||||
for folder in val_folders:
|
||||
new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240)
|
||||
os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx)))
|
289
codes/train2.py
289
codes/train2.py
|
@ -1,289 +0,0 @@
|
|||
import os
|
||||
import math
|
||||
import argparse
|
||||
import random
|
||||
import logging
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
|
||||
import options.options as option
|
||||
from utils import util
|
||||
from data import create_dataloader, create_dataset
|
||||
from models import create_model
|
||||
from time import time
|
||||
|
||||
|
||||
def init_dist(backend='nccl', **kwargs):
|
||||
# These packages have globals that screw with Windows, so only import them if needed.
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
"""initialization for distributed training"""
|
||||
if mp.get_start_method(allow_none=True) != 'spawn':
|
||||
mp.set_start_method('spawn')
|
||||
rank = int(os.environ['RANK'])
|
||||
num_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(rank % num_gpus)
|
||||
dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2_fullimgref.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
||||
colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode']
|
||||
if colab_mode:
|
||||
# Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there.
|
||||
# Each one should have a TEST file in it.
|
||||
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||
os.path.join(opt['remote_path'], 'training_state', "TEST"))
|
||||
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||
os.path.join(opt['remote_path'], 'models', "TEST"))
|
||||
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||
os.path.join(opt['remote_path'], 'val_images', "TEST"))
|
||||
# Load the state and models needed from the remote server.
|
||||
if opt['path']['resume_state']:
|
||||
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state']))
|
||||
if opt['path']['pretrain_model_G']:
|
||||
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G']))
|
||||
if opt['path']['pretrain_model_D']:
|
||||
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D']))
|
||||
|
||||
#### distributed training settings
|
||||
if args.launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
rank = -1
|
||||
print('Disabled distributed training.')
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
# distributed resuming: all load into default GPU
|
||||
device_id = torch.cuda.current_device()
|
||||
resume_state = torch.load(opt['path']['resume_state'],
|
||||
map_location=lambda storage, loc: storage.cuda(device_id))
|
||||
option.check_resume(opt, resume_state['iter']) # check resume options
|
||||
else:
|
||||
resume_state = None
|
||||
|
||||
#### mkdir and loggers
|
||||
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
||||
if resume_state is None:
|
||||
util.mkdir_and_rename(
|
||||
opt['path']['experiments_root']) # rename experiment folder if exists
|
||||
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
|
||||
and 'pretrain_model' not in key and 'resume' not in key))
|
||||
|
||||
# config loggers. Before it, the log will not work
|
||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
||||
screen=True, tofile=True)
|
||||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
||||
version = float(torch.__version__[0:3])
|
||||
if version >= 1.1: # PyTorch 1.1
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
else:
|
||||
logger.info(
|
||||
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
||||
from tensorboardX import SummaryWriter
|
||||
tb_logger = SummaryWriter(log_dir=tb_logger_path)
|
||||
else:
|
||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
# convert to NoneDict, which returns None for missing keys
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
#### random seed
|
||||
seed = opt['train']['manual_seed']
|
||||
if seed is None:
|
||||
seed = random.randint(1, 10000)
|
||||
if rank <= 0:
|
||||
logger.info('Random seed: {}'.format(seed))
|
||||
util.set_random_seed(seed)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
|
||||
#### create train and val dataloader
|
||||
dataset_ratio = 200 # enlarge the size of each epoch
|
||||
for phase, dataset_opt in opt['datasets'].items():
|
||||
if phase == 'train':
|
||||
train_set = create_dataset(dataset_opt)
|
||||
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
||||
total_iters = int(opt['train']['niter'])
|
||||
total_epochs = int(math.ceil(total_iters / train_size))
|
||||
if opt['dist']:
|
||||
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
|
||||
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||
else:
|
||||
train_sampler = None
|
||||
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
||||
if rank <= 0:
|
||||
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(train_set), train_size))
|
||||
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
||||
total_epochs, total_iters))
|
||||
elif phase == 'val':
|
||||
val_set = create_dataset(dataset_opt)
|
||||
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
|
||||
if rank <= 0:
|
||||
logger.info('Number of val images in [{:s}]: {:d}'.format(
|
||||
dataset_opt['name'], len(val_set)))
|
||||
else:
|
||||
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
||||
assert train_loader is not None
|
||||
|
||||
#### create model
|
||||
model = create_model(opt)
|
||||
|
||||
#### resume training
|
||||
if resume_state:
|
||||
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
||||
resume_state['epoch'], resume_state['iter']))
|
||||
|
||||
start_epoch = resume_state['epoch']
|
||||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
start_epoch = 0
|
||||
|
||||
#### training
|
||||
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
||||
for epoch in range(start_epoch, total_epochs + 1):
|
||||
if opt['dist']:
|
||||
train_sampler.set_epoch(epoch)
|
||||
tq_ldr = tqdm(train_loader)
|
||||
|
||||
_t = time()
|
||||
_profile = False
|
||||
for _, train_data in enumerate(tq_ldr):
|
||||
if _profile:
|
||||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
current_step += 1
|
||||
if current_step > total_iters:
|
||||
break
|
||||
#### update learning rate
|
||||
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
||||
|
||||
#### training
|
||||
if _profile:
|
||||
print("Update LR: %f" % (time() - _t))
|
||||
_t = time()
|
||||
model.feed_data(train_data)
|
||||
model.optimize_parameters(current_step)
|
||||
if _profile:
|
||||
print("Model feed + step: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
#### log
|
||||
if current_step % opt['logger']['print_freq'] == 0:
|
||||
logs = model.get_current_log(current_step)
|
||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||
for v in model.get_current_learning_rate():
|
||||
message += '{:.3e},'.format(v)
|
||||
message += ')] '
|
||||
for k, v in logs.items():
|
||||
if 'histogram' in k:
|
||||
if rank <= 0:
|
||||
tb_logger.add_histogram(k, v, current_step)
|
||||
else:
|
||||
message += '{:s}: {:.4e} '.format(k, v)
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
if rank <= 0:
|
||||
tb_logger.add_scalar(k, v, current_step)
|
||||
if rank <= 0:
|
||||
logger.info(message)
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||
model.force_restore_swapout()
|
||||
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
||||
# does not support multi-GPU validation
|
||||
pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
|
||||
avg_psnr = 0.
|
||||
avg_fea_loss = 0.
|
||||
idx = 0
|
||||
colab_imgs_to_copy = []
|
||||
for val_data in val_loader:
|
||||
idx += 1
|
||||
for b in range(len(val_data['LQ_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
|
||||
model.feed_data(val_data)
|
||||
model.test()
|
||||
|
||||
visuals = model.get_current_visuals()
|
||||
if visuals is None:
|
||||
continue
|
||||
|
||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
||||
#gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
|
||||
# Save SR images for reference
|
||||
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
|
||||
save_img_path = os.path.join(img_dir, img_base_name)
|
||||
util.save_img(sr_img, save_img_path)
|
||||
if colab_mode:
|
||||
colab_imgs_to_copy.append(save_img_path)
|
||||
|
||||
# calculate PSNR (Naw - don't do that. PSNR sucks)
|
||||
#sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
#avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
#pbar.update('Test {}'.format(img_name))
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
if colab_mode:
|
||||
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||
colab_imgs_to_copy,
|
||||
os.path.join(opt['remote_path'], 'val_images', img_base_name))
|
||||
|
||||
avg_psnr = avg_psnr / idx
|
||||
avg_fea_loss = avg_fea_loss / idx
|
||||
|
||||
# log
|
||||
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
#tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
if rank <= 0:
|
||||
logger.info('Saving models and training states.')
|
||||
model.save(current_step)
|
||||
model.save_training_state(epoch, current_step)
|
||||
|
||||
if rank <= 0:
|
||||
logger.info('Saving the final model.')
|
||||
model.save('latest')
|
||||
logger.info('End of training.')
|
||||
tb_logger.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user