24792bdb4f
Removed a lot of legacy stuff I have no intent on using again. Plan is to shape this repo into something more extensible (get it? hah!)
308 lines
14 KiB
Python
308 lines
14 KiB
Python
import os
|
|
import math
|
|
import argparse
|
|
import random
|
|
import logging
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
from data.data_sampler import DistIterSampler
|
|
|
|
from utils import util, options as option
|
|
from data import create_dataloader, create_dataset
|
|
from models.ExtensibleTrainer import ExtensibleTrainer
|
|
from time import time
|
|
|
|
|
|
def init_dist(backend='nccl', **kwargs):
|
|
# These packages have globals that screw with Windows, so only import them if needed.
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
|
|
"""initialization for distributed training"""
|
|
if mp.get_start_method(allow_none=True) != 'spawn':
|
|
mp.set_start_method('spawn')
|
|
rank = int(os.environ['RANK'])
|
|
num_gpus = torch.cuda.device_count()
|
|
torch.cuda.set_device(rank % num_gpus)
|
|
dist.init_process_group(backend=backend, **kwargs)
|
|
|
|
def main():
|
|
#### options
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_bigswitch.yml')
|
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
|
parser.add_argument('--local_rank', type=int, default=0)
|
|
args = parser.parse_args()
|
|
opt = option.parse(args.opt, is_train=True)
|
|
|
|
colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode']
|
|
if colab_mode:
|
|
# Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there.
|
|
# Each one should have a TEST file in it.
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
os.path.join(opt['remote_path'], 'training_state', "TEST"))
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
os.path.join(opt['remote_path'], 'models', "TEST"))
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
os.path.join(opt['remote_path'], 'val_images', "TEST"))
|
|
# Load the state and models needed from the remote server.
|
|
if opt['path']['resume_state']:
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state']))
|
|
if opt['path']['pretrain_model_G']:
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G']))
|
|
if opt['path']['pretrain_model_D']:
|
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D']))
|
|
|
|
#### distributed training settings
|
|
if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
|
|
gpu = input('I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
|
|
'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
|
|
if gpu:
|
|
opt['gpu_ids'] = [int(gpu)]
|
|
if args.launcher == 'none': # disabled distributed training
|
|
opt['dist'] = False
|
|
rank = -1
|
|
print('Disabled distributed training.')
|
|
|
|
else:
|
|
opt['dist'] = True
|
|
init_dist()
|
|
world_size = torch.distributed.get_world_size()
|
|
rank = torch.distributed.get_rank()
|
|
|
|
#### loading resume state if exists
|
|
if opt['path'].get('resume_state', None):
|
|
# distributed resuming: all load into default GPU
|
|
device_id = torch.cuda.current_device()
|
|
resume_state = torch.load(opt['path']['resume_state'],
|
|
map_location=lambda storage, loc: storage.cuda(device_id))
|
|
option.check_resume(opt, resume_state['iter']) # check resume options
|
|
else:
|
|
resume_state = None
|
|
|
|
#### mkdir and loggers
|
|
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
|
if resume_state is None:
|
|
util.mkdir_and_rename(
|
|
opt['path']['experiments_root']) # rename experiment folder if exists
|
|
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
|
|
and 'pretrain_model' not in key and 'resume' not in key))
|
|
|
|
# config loggers. Before it, the log will not work
|
|
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
|
screen=True, tofile=True)
|
|
logger = logging.getLogger('base')
|
|
logger.info(option.dict2str(opt))
|
|
# tensorboard logger
|
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
|
version = float(torch.__version__[0:3])
|
|
if version >= 1.1: # PyTorch 1.1
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
else:
|
|
logger.info(
|
|
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
|
from tensorboardX import SummaryWriter
|
|
tb_logger = SummaryWriter(log_dir=tb_logger_path)
|
|
else:
|
|
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
|
logger = logging.getLogger('base')
|
|
|
|
# convert to NoneDict, which returns None for missing keys
|
|
opt = option.dict_to_nonedict(opt)
|
|
|
|
#### random seed
|
|
seed = opt['train']['manual_seed']
|
|
if seed is None:
|
|
seed = random.randint(1, 10000)
|
|
if rank <= 0:
|
|
logger.info('Random seed: {}'.format(seed))
|
|
util.set_random_seed(seed)
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
# torch.backends.cudnn.deterministic = True
|
|
# torch.autograd.set_detect_anomaly(True)
|
|
|
|
# Save the compiled opt dict to the global loaded_options variable.
|
|
util.loaded_options = opt
|
|
|
|
#### create train and val dataloader
|
|
dataset_ratio = 1 # enlarge the size of each epoch
|
|
for phase, dataset_opt in opt['datasets'].items():
|
|
if phase == 'train':
|
|
train_set = create_dataset(dataset_opt)
|
|
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
|
total_iters = int(opt['train']['niter'])
|
|
total_epochs = int(math.ceil(total_iters / train_size))
|
|
if opt['dist']:
|
|
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
|
|
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
|
else:
|
|
train_sampler = None
|
|
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
|
if rank <= 0:
|
|
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
|
len(train_set), train_size))
|
|
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
|
total_epochs, total_iters))
|
|
elif phase == 'val':
|
|
val_set = create_dataset(dataset_opt)
|
|
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
|
|
if rank <= 0:
|
|
logger.info('Number of val images in [{:s}]: {:d}'.format(
|
|
dataset_opt['name'], len(val_set)))
|
|
else:
|
|
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
|
assert train_loader is not None
|
|
|
|
#### create model
|
|
model = ExtensibleTrainer(opt)
|
|
|
|
#### resume training
|
|
if resume_state:
|
|
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
|
resume_state['epoch'], resume_state['iter']))
|
|
|
|
start_epoch = resume_state['epoch']
|
|
current_step = resume_state['iter']
|
|
model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
|
|
else:
|
|
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
|
start_epoch = 0
|
|
|
|
#### training
|
|
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
|
for epoch in range(start_epoch, total_epochs + 1):
|
|
if opt['dist']:
|
|
train_sampler.set_epoch(epoch)
|
|
tq_ldr = tqdm(train_loader)
|
|
|
|
_t = time()
|
|
_profile = False
|
|
for train_data in tq_ldr:
|
|
if _profile:
|
|
print("Data fetch: %f" % (time() - _t))
|
|
_t = time()
|
|
|
|
#tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
|
|
# train_data['lq_fullsize_ref'].float().to('cuda'),
|
|
# train_data['lq_center'].to('cuda')])
|
|
|
|
current_step += 1
|
|
if current_step > total_iters:
|
|
break
|
|
#### update learning rate
|
|
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
|
|
|
#### training
|
|
if _profile:
|
|
print("Update LR: %f" % (time() - _t))
|
|
_t = time()
|
|
model.feed_data(train_data)
|
|
model.optimize_parameters(current_step)
|
|
if _profile:
|
|
print("Model feed + step: %f" % (time() - _t))
|
|
_t = time()
|
|
|
|
#### log
|
|
if current_step % opt['logger']['print_freq'] == 0:
|
|
logs = model.get_current_log(current_step)
|
|
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
|
for v in model.get_current_learning_rate():
|
|
message += '{:.3e},'.format(v)
|
|
message += ')] '
|
|
for k, v in logs.items():
|
|
if 'histogram' in k:
|
|
if rank <= 0:
|
|
tb_logger.add_histogram(k, v, current_step)
|
|
else:
|
|
message += '{:s}: {:.4e} '.format(k, v)
|
|
# tensorboard logger
|
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
if rank <= 0:
|
|
tb_logger.add_scalar(k, v, current_step)
|
|
if rank <= 0:
|
|
logger.info(message)
|
|
|
|
#### save models and training states
|
|
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
|
if rank <= 0:
|
|
logger.info('Saving models and training states.')
|
|
model.save(current_step)
|
|
model.save_training_state(epoch, current_step)
|
|
if 'alt_path' in opt['path'].keys():
|
|
import shutil
|
|
print("Synchronizing tb_logger to alt_path..")
|
|
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
|
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
|
shutil.copytree(tb_logger_path, alt_tblogger)
|
|
|
|
#### validation
|
|
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
|
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
|
model.force_restore_swapout()
|
|
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
|
# does not support multi-GPU validation
|
|
avg_psnr = 0.
|
|
avg_fea_loss = 0.
|
|
idx = 0
|
|
colab_imgs_to_copy = []
|
|
val_tqdm = tqdm(val_loader)
|
|
for val_data in val_tqdm:
|
|
idx += 1
|
|
for b in range(len(val_data['LQ_path'])):
|
|
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
|
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
|
util.mkdir(img_dir)
|
|
|
|
model.feed_data(val_data)
|
|
model.test()
|
|
|
|
visuals = model.get_current_visuals()
|
|
if visuals is None:
|
|
continue
|
|
|
|
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
|
#gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
|
|
|
# Save SR images for reference
|
|
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
|
|
save_img_path = os.path.join(img_dir, img_base_name)
|
|
util.save_img(sr_img, save_img_path)
|
|
if colab_mode:
|
|
colab_imgs_to_copy.append(save_img_path)
|
|
|
|
# calculate PSNR (Naw - don't do that. PSNR sucks)
|
|
#sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
|
#avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
|
#pbar.update('Test {}'.format(img_name))
|
|
|
|
# calculate fea loss
|
|
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
|
|
|
if colab_mode:
|
|
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
|
colab_imgs_to_copy,
|
|
os.path.join(opt['remote_path'], 'val_images', img_base_name))
|
|
|
|
avg_psnr = avg_psnr / idx
|
|
avg_fea_loss = avg_fea_loss / idx
|
|
|
|
# log
|
|
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
|
# tensorboard logger
|
|
if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
|
|
#tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
|
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
|
|
|
|
if rank <= 0:
|
|
logger.info('Saving the final model.')
|
|
model.save('latest')
|
|
logger.info('End of training.')
|
|
tb_logger.close()
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|