mod trainer to be very explicit about the fact that loading models and state together doesn't work, but allow it
parent bb0a0c8264
commit 2afea126d7
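The change targets option files that set both a resume state and an explicit pretrained-model path for the same network. A minimal sketch of the conflicting configuration (all values hypothetical; the key names come from the diff below):

    opt = {
        'networks': {'generator': {}},
        'path': {
            # Resuming a training run from a saved state...
            'resume_state': 'experiments/my_run/training_state/5000.state',
            # ...while also pointing the same network at a pretrained model.
            'pretrain_model_generator': 'pretrained/generator_init.pth',
            'models': 'experiments/my_run/models',
        },
    }

Previously check_resume() asserted this combination away; after this commit it is allowed but warned about loudly.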
@@ -43,7 +43,6 @@ class Trainer:
             device_id = torch.cuda.current_device()
             resume_state = torch.load(opt['path']['resume_state'],
                                       map_location=lambda storage, loc: storage.cuda(device_id))
-            option.check_resume(opt, resume_state['iter'])  # check resume options
         else:
             resume_state = None
 
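The map_location lambda in this hunk remaps every tensor in the checkpoint onto the GPU of the loading process rather than the device recorded at save time, which matters when resuming distributed runs. A standalone sketch (the checkpoint path is hypothetical):

    import torch

    state_path = 'experiments/my_run/training_state/5000.state'  # hypothetical
    device_id = torch.cuda.current_device()
    # Remap stored tensors onto this process's GPU instead of the device
    # they were saved from (each rank loads onto its own device).
    resume_state = torch.load(state_path,
                              map_location=lambda storage, loc: storage.cuda(device_id))
    print(resume_state['iter'])  # the iteration counter passed to check_resume()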
@@ -64,18 +63,15 @@ class Trainer:
             # tensorboard logger
             if opt['use_tb_logger'] and 'debug' not in opt['name']:
                 self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
                 version = float(torch.__version__[0:3])
                 if version >= 1.1:  # PyTorch 1.1
                     from torch.utils.tensorboard import SummaryWriter
                 else:
                     self.logger.info(
                         'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                     from tensorboardX import SummaryWriter
                 self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
         else:
             util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
             self.logger = logging.getLogger('base')
 
+        if resume_state is not None:
+            option.check_resume(opt, resume_state['iter'])  # check resume options
+
         # convert to NoneDict, which returns None for missing keys
         opt = option.dict_to_nonedict(opt)
         self.opt = opt
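Incidentally, float(torch.__version__[0:3]) only inspects the first three characters of the version string, so it is fragile against multi-digit minor versions. A more robust gate, shown here as an alternative sketch rather than the repo's code, uses the widely available packaging library:

    from packaging import version
    import torch

    # Parses full version strings, including suffixes like '1.13.1+cu117'.
    if version.parse(torch.__version__) >= version.parse('1.1'):
        from torch.utils.tensorboard import SummaryWriter
    else:
        from tensorboardX import SummaryWriter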
@@ -284,7 +280,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
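For reference, these flags match the usual launch patterns: a single-process run is invoked as python train.py -opt <options.yml>, while a distributed run typically goes through python -m torch.distributed.launch with --launcher pytorch, which is what supplies the --local_rank argument parsed above.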
@@ -103,7 +103,20 @@ def check_resume(opt, resume_iter):
         # Automatically fill in the network paths for a given resume iteration.
         for k in opt['networks'].keys():
             pt_key = 'pretrain_model_%s' % (k,)
-            assert pt_key not in opt['path'].keys()  # There's no real reason to load from a training_state AND a model.
-            opt['path'][pt_key] = osp.join(opt['path']['models'],
-                                           '{}_{}.pth'.format(resume_iter, k))
+            if pt_key in opt['path'].keys():
+                # This is a dicey, error prone situation that has bitten me in both ways it can be handled. Opt for
+                # a big, verbose error message.
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!WARNING!! YOU SPECIFIED A PRETRAINED MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!AND A RESUME STATE PATH. THERE IS NO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!GOOD WAY TO HANDLE THIS SO WE JUST IGNORE!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!THE MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+            else:
+                opt['path'][pt_key] = osp.join(opt['path']['models'],
+                                               '{}_{}.pth'.format(resume_iter, k))
             logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
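After the change, the only automatic behavior left is the else branch: when no explicit pretrain_model_* path is present, the model path is derived from the resume iteration and the network name. A standalone sketch of that derivation (directory and network name hypothetical):

    import os.path as osp

    resume_iter = 5000
    k = 'generator'                           # a key of opt['networks']
    models_dir = 'experiments/my_run/models'  # hypothetical opt['path']['models']

    # Mirrors the else branch above.
    pt_path = osp.join(models_dir, '{}_{}.pth'.format(resume_iter, k))
    print(pt_path)  # experiments/my_run/models/5000_generator.pth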