diff --git a/codes/train.py b/codes/train.py index 0207a496..1ea05fa8 100644 --- a/codes/train.py +++ b/codes/train.py @@ -280,9 +280,9 @@ def main(): if 'alt_path' in opt['path'].keys(): import shutil print("Synchronizing tb_logger to alt_path..") - alt_tblogger = os.path.join(tb_logger_path, "tb_logger") + alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger") shutil.rmtree(alt_tblogger, ignore_errors=True) - shutil.copytree(alt_tblogger, opt['path']['alt_path']) + shutil.copytree(tb_logger_path, alt_tblogger) if rank <= 0: logger.info('Saving the final model.')