forked from mrq/DL-Art-School
Clear out tensorboard on job restart.
This commit is contained in:
parent
b7857f35c3
commit
f027e888ed
|
@ -3,6 +3,7 @@ import math
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
|
import shutil
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -72,6 +73,12 @@ def main():
|
||||||
logger.info(option.dict2str(opt))
|
logger.info(option.dict2str(opt))
|
||||||
# tensorboard logger
|
# tensorboard logger
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
|
# If not resuming, delete the existing logs. Tensorboard doesn't do too great with these.
|
||||||
|
if 'resume_state' not in opt['path'].keys():
|
||||||
|
tb_logger_path = '../tb_logger/' + opt['name']
|
||||||
|
if os.path.exists(tb_logger_path) and os.path.isdir(tb_logger_path):
|
||||||
|
shutil.rmtree(tb_logger_path)
|
||||||
|
|
||||||
version = float(torch.__version__[0:3])
|
version = float(torch.__version__[0:3])
|
||||||
if version >= 1.1: # PyTorch 1.1
|
if version >= 1.1: # PyTorch 1.1
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
@ -79,7 +86,7 @@ def main():
|
||||||
logger.info(
|
logger.info(
|
||||||
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
|
tb_logger = SummaryWriter(log_dir=tb_logger_path)
|
||||||
else:
|
else:
|
||||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user