diff --git a/codes/train.py b/codes/train.py index dd16ebd8..0207a496 100644 --- a/codes/train.py +++ b/codes/train.py @@ -280,7 +280,9 @@ def main(): if 'alt_path' in opt['path'].keys(): import shutil print("Synchronizing tb_logger to alt_path..") - shutil.copytree(tb_logger_path, opt['path']['alt_path']) + alt_tblogger = os.path.join(tb_logger_path, "tb_logger") + shutil.rmtree(alt_tblogger, ignore_errors=True) + shutil.copytree(alt_tblogger, opt['path']['alt_path']) if rank <= 0: logger.info('Saving the final model.')