Modify the trainer to be very explicit about the fact that loading models and a resume state together doesn't work, but allow it
This commit is contained in:
parent
bb0a0c8264
commit
2afea126d7
|
@ -43,7 +43,6 @@ class Trainer:
|
||||||
device_id = torch.cuda.current_device()
|
device_id = torch.cuda.current_device()
|
||||||
resume_state = torch.load(opt['path']['resume_state'],
|
resume_state = torch.load(opt['path']['resume_state'],
|
||||||
map_location=lambda storage, loc: storage.cuda(device_id))
|
map_location=lambda storage, loc: storage.cuda(device_id))
|
||||||
option.check_resume(opt, resume_state['iter']) # check resume options
|
|
||||||
else:
|
else:
|
||||||
resume_state = None
|
resume_state = None
|
||||||
|
|
||||||
|
@ -64,18 +63,15 @@ class Trainer:
|
||||||
# tensorboard logger
|
# tensorboard logger
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
||||||
version = float(torch.__version__[0:3])
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
if version >= 1.1: # PyTorch 1.1
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
else:
|
|
||||||
self.self.logger.info(
|
|
||||||
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
|
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
|
||||||
else:
|
else:
|
||||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
||||||
self.logger = logging.getLogger('base')
|
self.logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
if resume_state is not None:
|
||||||
|
option.check_resume(opt, resume_state['iter']) # check resume options
|
||||||
|
|
||||||
# convert to NoneDict, which returns None for missing keys
|
# convert to NoneDict, which returns None for missing keys
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
|
@ -284,7 +280,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -103,7 +103,20 @@ def check_resume(opt, resume_iter):
|
||||||
# Automatically fill in the network paths for a given resume iteration.
|
# Automatically fill in the network paths for a given resume iteration.
|
||||||
for k in opt['networks'].keys():
|
for k in opt['networks'].keys():
|
||||||
pt_key = 'pretrain_model_%s' % (k,)
|
pt_key = 'pretrain_model_%s' % (k,)
|
||||||
assert pt_key not in opt['path'].keys() # There's no real reason to load from a training_state AND a model.
|
if pt_key in opt['path'].keys():
|
||||||
opt['path'][pt_key] = osp.join(opt['path']['models'],
|
# This is a dicey, error prone situation that has bitten me in both ways it can be handled. Opt for
|
||||||
'{}_{}.pth'.format(resume_iter, k))
|
# a big, verbose error message.
|
||||||
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!WARNING!! YOU SPECIFIED A PRETRAINED MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!AND A RESUME STATE PATH. THERE IS NO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!GOOD WAY TO HANDLE THIS SO WE JUST IGNORE!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!THE MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
else:
|
||||||
|
opt['path'][pt_key] = osp.join(opt['path']['models'],
|
||||||
|
'{}_{}.pth'.format(resume_iter, k))
|
||||||
|
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user