mod trainer to be very explicit about the fact that loading models and state together don't work, but allow it

This commit is contained in:
James Betker 2021-10-28 22:32:42 -06:00
parent bb0a0c8264
commit 2afea126d7
2 changed files with 22 additions and 13 deletions

View File

@ -43,7 +43,6 @@ class Trainer:
device_id = torch.cuda.current_device() device_id = torch.cuda.current_device()
resume_state = torch.load(opt['path']['resume_state'], resume_state = torch.load(opt['path']['resume_state'],
map_location=lambda storage, loc: storage.cuda(device_id)) map_location=lambda storage, loc: storage.cuda(device_id))
option.check_resume(opt, resume_state['iter']) # check resume options
else: else:
resume_state = None resume_state = None
@ -64,18 +63,15 @@ class Trainer:
# tensorboard logger # tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']: if opt['use_tb_logger'] and 'debug' not in opt['name']:
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger') self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
version = float(torch.__version__[0:3]) from torch.utils.tensorboard import SummaryWriter
if version >= 1.1: # PyTorch 1.1
from torch.utils.tensorboard import SummaryWriter
else:
self.self.logger.info(
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
from tensorboardX import SummaryWriter
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path) self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
else: else:
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
self.logger = logging.getLogger('base') self.logger = logging.getLogger('base')
if resume_state is not None:
option.check_resume(opt, resume_state['iter']) # check resume options
# convert to NoneDict, which returns None for missing keys # convert to NoneDict, which returns None for missing keys
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
self.opt = opt self.opt = opt
@ -284,7 +280,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -103,7 +103,20 @@ def check_resume(opt, resume_iter):
# Automatically fill in the network paths for a given resume iteration. # Automatically fill in the network paths for a given resume iteration.
for k in opt['networks'].keys(): for k in opt['networks'].keys():
pt_key = 'pretrain_model_%s' % (k,) pt_key = 'pretrain_model_%s' % (k,)
assert pt_key not in opt['path'].keys() # There's no real reason to load from a training_state AND a model. if pt_key in opt['path'].keys():
opt['path'][pt_key] = osp.join(opt['path']['models'], # This is a dicey, error prone situation that has bitten me in both ways it can be handled. Opt for
'{}_{}.pth'.format(resume_iter, k)) # a big, verbose error message.
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key])) print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!WARNING!! YOU SPECIFIED A PRETRAINED MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!AND A RESUME STATE PATH. THERE IS NO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!GOOD WAY TO HANDLE THIS SO WE JUST IGNORE!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!THE MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
else:
opt['path'][pt_key] = osp.join(opt['path']['models'],
'{}_{}.pth'.format(resume_iter, k))
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))