296135ec18
doResizeLoss has a 50% chance to resize the LQ image to 50% size, then back to original size. This is useful to training a generator to recover these lost pixel values while also being able to do repairs on higher resolution images during training.
269 lines
12 KiB
Python
269 lines
12 KiB
Python
import os
|
|
import math
|
|
import argparse
|
|
import random
|
|
import logging
|
|
import shutil
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
#import torch.distributed as dist
|
|
#import torch.multiprocessing as mp
|
|
from data.data_sampler import DistIterSampler
|
|
|
|
import options.options as option
|
|
from utils import util
|
|
from data import create_dataloader, create_dataset
|
|
from models import create_model
|
|
|
|
|
|
def init_dist(backend='nccl', **kwargs):
|
|
"""initialization for distributed training"""
|
|
if mp.get_start_method(allow_none=True) != 'spawn':
|
|
mp.set_start_method('spawn')
|
|
rank = int(os.environ['RANK'])
|
|
num_gpus = torch.cuda.device_count()
|
|
torch.cuda.set_device(rank % num_gpus)
|
|
dist.init_process_group(backend=backend, **kwargs)
|
|
|
|
|
|
def main():
|
|
#### options
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_switched_rrdb_small.yml')
|
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
|
help='job launcher')
|
|
parser.add_argument('--local_rank', type=int, default=0)
|
|
args = parser.parse_args()
|
|
opt = option.parse(args.opt, is_train=True)
|
|
|
|
colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode']
|
|
if colab_mode:
|
|
# Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there.
|
|
# Each one should have a TEST file in it.
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
os.path.join(opt['remote_path'], 'training_state', "TEST"))
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
os.path.join(opt['remote_path'], 'models', "TEST"))
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
os.path.join(opt['remote_path'], 'val_images', "TEST"))
|
|
# Load the state and models needed from the remote server.
|
|
if opt['path']['resume_state']:
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state']))
|
|
if opt['path']['pretrain_model_G']:
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G']))
|
|
if opt['path']['pretrain_model_D']:
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D']))
|
|
|
|
#### distributed training settings
|
|
if args.launcher == 'none': # disabled distributed training
|
|
opt['dist'] = False
|
|
rank = -1
|
|
print('Disabled distributed training.')
|
|
else:
|
|
opt['dist'] = True
|
|
init_dist()
|
|
world_size = torch.distributed.get_world_size()
|
|
rank = torch.distributed.get_rank()
|
|
|
|
#### loading resume state if exists
|
|
if opt['path'].get('resume_state', None):
|
|
# distributed resuming: all load into default GPU
|
|
device_id = torch.cuda.current_device()
|
|
resume_state = torch.load(opt['path']['resume_state'],
|
|
map_location=lambda storage, loc: storage.cuda(device_id))
|
|
option.check_resume(opt, resume_state['iter']) # check resume options
|
|
else:
|
|
resume_state = None
|
|
|
|
#### mkdir and loggers
|
|
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
|
if resume_state is None:
|
|
util.mkdir_and_rename(
|
|
opt['path']['experiments_root']) # rename experiment folder if exists
|
|
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
|
|
and 'pretrain_model' not in key and 'resume' not in key))
|
|
|
|
# config loggers. Before it, the log will not work
|
|
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
|
screen=True, tofile=True)
|
|
logger = logging.getLogger('base')
|
|
logger.info(option.dict2str(opt))
|
|
# tensorboard logger
|
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
tb_logger_path = '../tb_logger/' + opt['name']
|
|
|
|
# If not resuming, delete the existing logs. Tensorboard doesn't do too great with these.
|
|
if opt['path'].get('resume_state', None) is None:
|
|
if os.path.exists(tb_logger_path) and os.path.isdir(tb_logger_path):
|
|
shutil.rmtree(tb_logger_path)
|
|
|
|
version = float(torch.__version__[0:3])
|
|
if version >= 1.1: # PyTorch 1.1
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
else:
|
|
logger.info(
|
|
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
|
from tensorboardX import SummaryWriter
|
|
tb_logger = SummaryWriter(log_dir=tb_logger_path)
|
|
else:
|
|
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
|
logger = logging.getLogger('base')
|
|
|
|
# convert to NoneDict, which returns None for missing keys
|
|
opt = option.dict_to_nonedict(opt)
|
|
|
|
#### random seed
|
|
seed = opt['train']['manual_seed']
|
|
if seed is None:
|
|
seed = random.randint(1, 10000)
|
|
if rank <= 0:
|
|
logger.info('Random seed: {}'.format(seed))
|
|
util.set_random_seed(seed)
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
# torch.backends.cudnn.deterministic = True
|
|
|
|
#### create train and val dataloader
|
|
dataset_ratio = 200 # enlarge the size of each epoch
|
|
for phase, dataset_opt in opt['datasets'].items():
|
|
if phase == 'train':
|
|
train_set = create_dataset(dataset_opt)
|
|
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
|
total_iters = int(opt['train']['niter'])
|
|
total_epochs = int(math.ceil(total_iters / train_size))
|
|
if opt['dist']:
|
|
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
|
|
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
|
else:
|
|
train_sampler = None
|
|
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
|
if rank <= 0:
|
|
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
|
len(train_set), train_size))
|
|
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
|
total_epochs, total_iters))
|
|
elif phase == 'val':
|
|
val_set = create_dataset(dataset_opt)
|
|
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
|
|
if rank <= 0:
|
|
logger.info('Number of val images in [{:s}]: {:d}'.format(
|
|
dataset_opt['name'], len(val_set)))
|
|
else:
|
|
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
|
assert train_loader is not None
|
|
|
|
#### create model
|
|
model = create_model(opt)
|
|
|
|
#### resume training
|
|
if resume_state:
|
|
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
|
resume_state['epoch'], resume_state['iter']))
|
|
|
|
start_epoch = resume_state['epoch']
|
|
current_step = resume_state['iter']
|
|
model.resume_training(resume_state) # handle optimizers and schedulers
|
|
else:
|
|
current_step = -1
|
|
start_epoch = 0
|
|
|
|
#### training
|
|
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
|
for epoch in range(start_epoch, total_epochs + 1):
|
|
if opt['dist']:
|
|
train_sampler.set_epoch(epoch)
|
|
tq_ldr = tqdm(train_loader)
|
|
for _, train_data in enumerate(tq_ldr):
|
|
current_step += 1
|
|
if current_step > total_iters:
|
|
break
|
|
#### update learning rate
|
|
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
|
|
|
#### training
|
|
model.feed_data(train_data)
|
|
model.optimize_parameters(current_step)
|
|
|
|
#### log
|
|
if current_step % opt['logger']['print_freq'] == 0:
|
|
logs = model.get_current_log(current_step)
|
|
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
|
for v in model.get_current_learning_rate():
|
|
message += '{:.3e},'.format(v)
|
|
message += ')] '
|
|
for k, v in logs.items():
|
|
message += '{:s}: {:.4e} '.format(k, v)
|
|
# tensorboard logger
|
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
if rank <= 0:
|
|
tb_logger.add_scalar(k, v, current_step)
|
|
if rank <= 0:
|
|
logger.info(message)
|
|
#### validation
|
|
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
|
if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation
|
|
model.force_restore_swapout()
|
|
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
|
# does not support multi-GPU validation
|
|
pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
|
|
avg_psnr = 0.
|
|
idx = 0
|
|
colab_imgs_to_copy = []
|
|
for val_data in val_loader:
|
|
idx += 1
|
|
for b in range(len(val_data['LQ_path'])):
|
|
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
|
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
|
util.mkdir(img_dir)
|
|
|
|
model.feed_data(val_data)
|
|
model.test()
|
|
|
|
visuals = model.get_current_visuals()
|
|
|
|
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
|
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
|
|
|
# Save SR images for reference
|
|
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
|
|
save_img_path = os.path.join(img_dir, img_base_name)
|
|
util.save_img(sr_img, save_img_path)
|
|
if colab_mode:
|
|
colab_imgs_to_copy.append(save_img_path)
|
|
|
|
# calculate PSNR
|
|
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
|
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
|
pbar.update('Test {}'.format(img_name))
|
|
|
|
if colab_mode:
|
|
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
colab_imgs_to_copy,
|
|
os.path.join(opt['remote_path'], 'val_images', img_base_name))
|
|
|
|
avg_psnr = avg_psnr / idx
|
|
|
|
# log
|
|
logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
|
|
# tensorboard logger
|
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
tb_logger.add_scalar('psnr', avg_psnr, current_step)
|
|
|
|
#### save models and training states
|
|
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
|
if rank <= 0:
|
|
logger.info('Saving models and training states.')
|
|
model.save(current_step)
|
|
model.save_training_state(epoch, current_step)
|
|
|
|
if rank <= 0:
|
|
logger.info('Saving the final model.')
|
|
model.save('latest')
|
|
logger.info('End of training.')
|
|
tb_logger.close()
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|