diff --git a/codes/train.py b/codes/train.py index 4c3e8761..dd16ebd8 100644 --- a/codes/train.py +++ b/codes/train.py @@ -277,6 +277,10 @@ def main(): logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) + if 'alt_path' in opt['path'].keys(): + import shutil + print("Synchronizing tb_logger to alt_path..") + shutil.copytree(tb_logger_path, opt['path']['alt_path']) if rank <= 0: logger.info('Saving the final model.')