diff --git a/codes/train.py b/codes/train.py index 2e737d3f..f18cb224 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_ssgr1.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -179,6 +179,10 @@ def main(): print("Data fetch: %f" % (time() - _t)) _t = time() + #tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'), + # train_data['lq_fullsize_ref'].float().to('cuda'), + # train_data['lq_center'].to('cuda')]) + current_step += 1 if current_step > total_iters: break @@ -214,6 +218,20 @@ def main(): tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) + + #### save models and training states + if current_step % opt['logger']['save_checkpoint_freq'] == 0: + if rank <= 0: + logger.info('Saving models and training states.') + model.save(current_step) + model.save_training_state(epoch, current_step) + if 'alt_path' in opt['path'].keys(): + import shutil + print("Synchronizing tb_logger to alt_path..") + alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger") + shutil.rmtree(alt_tblogger, ignore_errors=True) + shutil.copytree(tb_logger_path, alt_tblogger) + #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation @@ -272,19 +290,6 @@ def main(): #tb_logger.add_scalar('val_psnr', avg_psnr, current_step) tb_logger.add_scalar('val_fea', avg_fea_loss, current_step) - #### save models and training states - if current_step % opt['logger']['save_checkpoint_freq'] == 0: - if rank <= 0: - logger.info('Saving models and training states.') - model.save(current_step) - model.save_training_state(epoch, current_step) - if 'alt_path' in opt['path'].keys(): - import shutil - print("Synchronizing tb_logger to alt_path..") - alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger") - shutil.rmtree(alt_tblogger, ignore_errors=True) - shutil.copytree(tb_logger_path, alt_tblogger) - if rank <= 0: logger.info('Saving the final model.') model.save('latest')