diff --git a/codes/train.py b/codes/train.py
index befd5d97..02f8f2bb 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -3,6 +3,7 @@ import math
 import argparse
 import random
 import logging
+import shutil
 
 from tqdm import tqdm
 import torch
@@ -72,6 +73,12 @@ def main():
         logger.info(option.dict2str(opt))
         # tensorboard logger
         if opt['use_tb_logger'] and 'debug' not in opt['name']:
+            tb_logger_path = '../tb_logger/' + opt['name']
+            # If not resuming, delete any existing logs: TensorBoard doesn't handle stale event files well.
+            if 'resume_state' not in opt['path']:
+                if os.path.isdir(tb_logger_path):
+                    shutil.rmtree(tb_logger_path)
+
             version = float(torch.__version__[0:3])
             if version >= 1.1:  # PyTorch 1.1
                 from torch.utils.tensorboard import SummaryWriter
@@ -79,7 +86,7 @@ def main():
                 logger.info(
                     'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                 from tensorboardX import SummaryWriter
-            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
+            tb_logger = SummaryWriter(log_dir=tb_logger_path)
     else:
         util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
         logger = logging.getLogger('base')
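
Note: tb_logger_path is now assigned unconditionally before the resume check, since the final hunk passes it to SummaryWriter on every code path; assigning it only inside the not-resuming branch would raise a NameError when resuming. For review context, the added logic reads in isolation as the minimal sketch below. init_tb_logger is a hypothetical helper name, not a function in the repo, and the opt layout (opt['name'], with opt['path'] holding 'resume_state' only when resuming) is assumed from how train.py uses it above.

# Sketch: resume-aware TensorBoard logger setup, factored out of main().
import os
import shutil

from torch.utils.tensorboard import SummaryWriter

def init_tb_logger(opt):
    tb_logger_path = '../tb_logger/' + opt['name']
    # On a fresh run, wipe stale event files: TensorBoard would otherwise
    # overlay the old run's scalars with the new run's on the same charts.
    if 'resume_state' not in opt['path'] and os.path.isdir(tb_logger_path):
        shutil.rmtree(tb_logger_path)
    # When resuming, old event files are kept so the curves continue.
    return SummaryWriter(log_dir=tb_logger_path)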