From 2afea126d73873f15f0e138e2a270e2e107c5f23 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 28 Oct 2021 22:32:42 -0600
Subject: [PATCH] mod trainer to be very explicit about the fact that loading
 models and state together dont work, but allow it

---
 codes/train.py         | 14 +++++---------
 codes/utils/options.py | 21 +++++++++++++++++----
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index 58803dc6..d2e2b973 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -43,7 +43,6 @@ class Trainer:
             device_id = torch.cuda.current_device()
             resume_state = torch.load(opt['path']['resume_state'],
                                       map_location=lambda storage, loc: storage.cuda(device_id))
-            option.check_resume(opt, resume_state['iter'])  # check resume options
         else:
             resume_state = None
 
@@ -64,18 +63,15 @@ class Trainer:
         # tensorboard logger
         if opt['use_tb_logger'] and 'debug' not in opt['name']:
             self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
-            version = float(torch.__version__[0:3])
-            if version >= 1.1:  # PyTorch 1.1
-                from torch.utils.tensorboard import SummaryWriter
-            else:
-                self.self.logger.info(
-                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
-                from tensorboardX import SummaryWriter
+            from torch.utils.tensorboard import SummaryWriter
             self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
         else:
             util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
             self.logger = logging.getLogger('base')
 
+        if resume_state is not None:
+            option.check_resume(opt, resume_state['iter'])  # check resume options
+
         # convert to NoneDict, which returns None for missing keys
         opt = option.dict_to_nonedict(opt)
         self.opt = opt
@@ -284,7 +280,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/utils/options.py b/codes/utils/options.py
index 5b152c00..613cb145 100644
--- a/codes/utils/options.py
+++ b/codes/utils/options.py
@@ -103,7 +103,20 @@ def check_resume(opt, resume_iter):
         # Automatically fill in the network paths for a given resume iteration.
         for k in opt['networks'].keys():
             pt_key = 'pretrain_model_%s' % (k,)
-            assert pt_key not in opt['path'].keys()  # There's no real reason to load from a training_state AND a model.
-            opt['path'][pt_key] = osp.join(opt['path']['models'],
-                                           '{}_{}.pth'.format(resume_iter, k))
-            logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
+            if pt_key in opt['path'].keys():
+                # This is a dicey, error prone situation that has bitten me in both ways it can be handled. Opt for
+                # a big, verbose error message.
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!WARNING!! YOU SPECIFIED A PRETRAINED MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!AND A RESUME STATE PATH. THERE IS NO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!GOOD WAY TO HANDLE THIS SO WE JUST IGNORE!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!THE MODEL PATH!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+            else:
+                opt['path'][pt_key] = osp.join(opt['path']['models'],
+                                               '{}_{}.pth'.format(resume_iter, k))
+                logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
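
For context, a minimal standalone sketch of the path-resolution behavior check_resume() ends up with after this patch. The opt dict, the network name 'generator', and the helper name resolve_model_paths are illustrative stand-ins, not names from the repo:

    import os.path as osp

    def resolve_model_paths(opt, resume_iter):
        # Mirrors the patched check_resume() loop: auto-fill each network's
        # checkpoint path from the resume iteration, but never overwrite a
        # pretrain_model_* path the user set explicitly.
        for k in opt['networks'].keys():
            pt_key = 'pretrain_model_%s' % (k,)
            if pt_key in opt['path']:
                # Both a pretrained model path and a resume state were given;
                # the patch warns loudly and leaves the explicit key untouched.
                print('WARNING: %s set alongside resume_state; not pointing it '
                      'at the resume-iteration checkpoint.' % pt_key)
            else:
                opt['path'][pt_key] = osp.join(opt['path']['models'],
                                               '{}_{}.pth'.format(resume_iter, k))
        return opt

    opt = {
        'networks': {'generator': {}},
        'path': {'models': '../experiments/demo/models',
                 'resume_state': '../experiments/demo/training_state/5000.state'},
    }
    resolve_model_paths(opt, 5000)
    print(opt['path']['pretrain_model_generator'])
    # -> ../experiments/demo/models/5000_generator.pth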