From 3e3d2af1f31ca50146f722e73e2cc72613dd7ffd Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 22 Oct 2020 13:27:32 -0600 Subject: [PATCH] Add multi-modal trainer --- codes/models/ExtensibleTrainer.py | 12 +- codes/models/steps/losses.py | 5 +- codes/multi_modal_train.py | 45 +++++ codes/train.py | 279 ++++++++++++++++++++++++++---- codes/train2.py | 2 +- 5 files changed, 303 insertions(+), 40 deletions(-) create mode 100644 codes/multi_modal_train.py diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 0ef0afd5..7b9878eb 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -19,7 +19,7 @@ logger = logging.getLogger('base') class ExtensibleTrainer(BaseModel): - def __init__(self, opt): + def __init__(self, opt, cached_networks={}): super(ExtensibleTrainer, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() @@ -49,11 +49,17 @@ class ExtensibleTrainer(BaseModel): if 'trainable' not in net.keys(): net['trainable'] = True + if name in cached_networks.keys(): + new_net = cached_networks[name] + else: + new_net = None if net['type'] == 'generator': - new_net = networks.define_G(net, None, opt['scale']).to(self.device) + if new_net is None: + new_net = networks.define_G(net, None, opt['scale']).to(self.device) self.netsG[name] = new_net elif net['type'] == 'discriminator': - new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device) + if new_net is None: + new_net = networks.define_D_net(net, opt['datasets']['train']['target_size']).to(self.device) self.netsD[name] = new_net else: raise NotImplementedError("Can only handle generators and discriminators") diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 9f0093e4..5cd8e174 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -41,7 +41,10 @@ def extract_params_from_state(params: object, state: object, root: object = True if isinstance(params, list) or isinstance(params, tuple): p = [extract_params_from_state(r, state, False) for r in params] elif isinstance(params, str): - p = state[params] + if params == 'None': + p = None + else: + p = state[params] else: p = params # The root return must always be a list. diff --git a/codes/multi_modal_train.py b/codes/multi_modal_train.py new file mode 100644 index 00000000..3e873622 --- /dev/null +++ b/codes/multi_modal_train.py @@ -0,0 +1,45 @@ +# This is a wrapper around train.py which allows you to train a set of models using a variety of different training +# paradigms. This works by using the yielding mechanism built into train.py to iterate one step at a time and +# synchronize the underlying models. +# +# Note that this wrapper is **EXTREMELY** simple and doesn't attempt to do many things. Some issues you should plan for: +# 1) Each trainer will have its own optimizer for the underlying model - even when the model is shared. +# 2) Each trainer will run validation and save model states according to its own schedule. Likewise: +# 3) Each trainer will load state params for the models it controls independently, regardless of whether or not those +# models are shared. Your best bet is to have all models save state at the same time so that they all load ~ the same +# state when re-started. +import argparse +import train +import utils.options as option + +def main(master_opt, launcher): + trainers = [] + all_networks = {} + shared_networks = [] + for i, sub_opt in enumerate(master_opt['trainer_options']): + sub_opt_parsed = option.parse(sub_opt, is_train=True) + # This creates trainers() as a list of generators. + train_gen = train.yielding_main(sub_opt_parsed, launcher, i, all_networks) + model = next(train_gen) + for k, v in model.networks.items(): + if k in all_networks.keys() and k not in shared_networks: + shared_networks.append(k) + all_networks[k] = v + trainers.append(train_gen) + print("Networks being shared by trainers: ", shared_networks) + + # Now, simply "iterate" through the trainers to accomplish training. + while True: + for trainer in trainers: + next(trainer) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + #parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') + args = parser.parse_args() + opt = { + 'trainer_options': ['../options/teco.yml', '../options/exd.yml'] + } + main(opt, args.launcher) \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index 931a4c95..954f7d07 100644 --- a/codes/train.py +++ b/codes/train.py @@ -27,40 +27,15 @@ def init_dist(backend='nccl', **kwargs): torch.cuda.set_device(rank % num_gpus) dist.init_process_group(backend=backend, **kwargs) -def main(): - #### options - parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') - parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) - args = parser.parse_args() - opt = option.parse(args.opt, is_train=True) - - colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode'] - if colab_mode: - # Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there. - # Each one should have a TEST file in it. - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - os.path.join(opt['remote_path'], 'training_state', "TEST")) - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - os.path.join(opt['remote_path'], 'models', "TEST")) - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - os.path.join(opt['remote_path'], 'val_images', "TEST")) - # Load the state and models needed from the remote server. - if opt['path']['resume_state']: - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state'])) - if opt['path']['pretrain_model_G']: - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G'])) - if opt['path']['pretrain_model_D']: - util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D'])) +def main(opt, launcher='none'): #### distributed training settings if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1: gpu = input('I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU' 'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids'])) if gpu: opt['gpu_ids'] = [int(gpu)] - if args.launcher == 'none': # disabled distributed training + if launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') @@ -257,9 +232,6 @@ def main(): if visuals is None: continue - if colab_mode: - colab_imgs_to_copy.append(save_img_path) - # calculate PSNR sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 gt_img = util.tensor2img(visuals['GT'][b]) # uint8 @@ -274,10 +246,242 @@ def main(): save_img_path = os.path.join(img_dir, img_base_name) util.save_img(sr_img, save_img_path) - if colab_mode: - util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], - colab_imgs_to_copy, - os.path.join(opt['remote_path'], 'val_images', img_base_name)) + avg_psnr = avg_psnr / idx + avg_fea_loss = avg_fea_loss / idx + + # log + logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0: + tb_logger.add_scalar('val_psnr', avg_psnr, current_step) + tb_logger.add_scalar('val_fea', avg_fea_loss, current_step) + + if rank <= 0: + logger.info('Saving the final model.') + model.save('latest') + logger.info('End of training.') + tb_logger.close() + +# TODO: Integrate with above main by putting this into an object and splitting up business logic. +def yielding_main(opt, launcher='none', trainer_id=0, all_networks={}): + #### distributed training settings + if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1: + gpu = input('I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU' + 'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids'])) + if gpu: + opt['gpu_ids'] = [int(gpu)] + if launcher == 'none': # disabled distributed training + opt['dist'] = False + rank = -1 + print('Disabled distributed training.') + + else: + opt['dist'] = True + init_dist() + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + #### loading resume state if exists + if opt['path'].get('resume_state', None): + # distributed resuming: all load into default GPU + device_id = torch.cuda.current_device() + resume_state = torch.load(opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + option.check_resume(opt, resume_state['iter']) # check resume options + else: + resume_state = None + + #### mkdir and loggers + if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) + if resume_state is None: + util.mkdir_and_rename( + opt['path']['experiments_root']) # rename experiment folder if exists + util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None + and 'pretrain_model' not in key and 'resume' not in key)) + + # config loggers. Before it, the log will not work + util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger') + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir=tb_logger_path) + else: + util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) + logger = logging.getLogger('base') + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + + #### random seed + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + if rank <= 0: + logger.info('Random seed: {}'.format(seed)) + util.set_random_seed(seed) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + # torch.autograd.set_detect_anomaly(True) + + # Save the compiled opt dict to the global loaded_options variable. + util.loaded_options = opt + + #### create train and val dataloader + dataset_ratio = 1 # enlarge the size of each epoch + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) + total_iters = int(opt['train']['niter']) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt['dist']: + train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if rank <= 0: + logger.info('Number of train images: {:,d}, iters: {:,d}'.format( + len(train_set), train_size)) + logger.info('Total epochs needed: {:d} for iters {:,d}'.format( + total_epochs, total_iters)) + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if rank <= 0: + logger.info('Number of val images in [{:s}]: {:d}'.format( + dataset_opt['name'], len(val_set))) + else: + raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) + assert train_loader is not None + + #### create model + model = ExtensibleTrainer(opt, all_networks) + + #### resume training + if resume_state: + logger.info('Resuming training from epoch: {}, iter: {}.'.format( + resume_state['epoch'], resume_state['iter'])) + + start_epoch = resume_state['epoch'] + current_step = resume_state['iter'] + model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers + else: + current_step = -1 if 'start_step' not in opt.keys() else opt['start_step'] + start_epoch = 0 + if 'force_start_step' in opt.keys(): + current_step = opt['force_start_step'] + + #### training + logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) + for epoch in range(start_epoch, total_epochs + 1): + if opt['dist']: + train_sampler.set_epoch(epoch) + tq_ldr = tqdm(train_loader, position=trainer_id) + + _t = time() + _profile = False + for train_data in tq_ldr: + # Yielding supports multi-modal trainer which operates multiple train.py instances. + yield model + + if _profile: + print("Data fetch: %f" % (time() - _t)) + _t = time() + + current_step += 1 + if current_step > total_iters: + break + #### update learning rate + model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) + + #### training + if _profile: + print("Update LR: %f" % (time() - _t)) + _t = time() + model.feed_data(train_data) + model.optimize_parameters(current_step) + if _profile: + print("Model feed + step: %f" % (time() - _t)) + _t = time() + + #### log + if current_step % opt['logger']['print_freq'] == 0 and rank <= 0: + logs = model.get_current_log(current_step) + message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) + for v in model.get_current_learning_rate(): + message += '{:.3e},'.format(v) + message += ')] ' + for k, v in logs.items(): + if 'histogram' in k: + tb_logger.add_histogram(k, v, current_step) + elif isinstance(v, dict): + tb_logger.add_scalars(k, v, current_step) + else: + message += '{:s}: {:.4e} '.format(k, v) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + tb_logger.add_scalar(k, v, current_step) + logger.info(message) + + #### save models and training states + if current_step % opt['logger']['save_checkpoint_freq'] == 0: + if rank <= 0: + logger.info('Saving models and training states.') + model.save(current_step) + model.save_training_state(epoch, current_step) + if 'alt_path' in opt['path'].keys(): + import shutil + print("Synchronizing tb_logger to alt_path..") + alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger") + shutil.rmtree(alt_tblogger, ignore_errors=True) + shutil.copytree(tb_logger_path, alt_tblogger) + + #### validation + if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: + if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation + avg_psnr = 0. + avg_fea_loss = 0. + idx = 0 + val_tqdm = tqdm(val_loader) + for val_data in val_tqdm: + idx += 1 + for b in range(len(val_data['LQ_path'])): + img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0] + img_dir = os.path.join(opt['path']['val_images'], img_name) + util.mkdir(img_dir) + + model.feed_data(val_data) + model.test() + + visuals = model.get_current_visuals() + if visuals is None: + continue + + # calculate PSNR + sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 + gt_img = util.tensor2img(visuals['GT'][b]) # uint8 + sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) + avg_psnr += util.calculate_psnr(sr_img, gt_img) + + # calculate fea loss + avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b]) + + # Save SR images for reference + img_base_name = '{:s}_{:d}.png'.format(img_name, current_step) + save_img_path = os.path.join(img_dir, img_base_name) + util.save_img(sr_img, save_img_path) avg_psnr = avg_psnr / idx avg_fea_loss = avg_fea_loss / idx @@ -297,4 +501,9 @@ def main(): if __name__ == '__main__': - main() + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') + parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') + args = parser.parse_args() + opt = option.parse(args.opt, is_train=True) + main(opt, args.launcher) diff --git a/codes/train2.py b/codes/train2.py index e6348a72..90933d43 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_chained_structured.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()