From 0ee0f46596158aa1d6b8f95b1e7637785c616ee3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 9 Mar 2023 00:29:25 +0000
Subject: [PATCH] .

---
 codes/train.py         | 8 +++-----
 codes/utils/options.py | 4 ++--
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index 8544070d..3b2b26e4 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -111,11 +111,6 @@ class Trainer:
             self.dataset_debugger = get_dataset_debugger(dataset_opt)
             if self.dataset_debugger is not None and resume_state is not None:
                 self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
-            # it will indefinitely try to train if your batch size is larger than your dataset
-            # could just whine when generating the YAML rather than assert here
-            if len(self.train_set) < dataset_opt['batch_size']:
-                dataset_opt['batch_size'] = len(self.train_set)
-                print("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
             train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
             total_iters = int(opt['train']['niter'])
             self.total_epochs = int(math.ceil(total_iters / train_size))
@@ -133,6 +128,9 @@ class Trainer:
                 self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                     self.total_epochs, total_iters))
             elif phase == 'val':
+                if not opt_get(opt, ['eval', 'pure'], False):
+                    continue
+
                 self.val_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
                 self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
                 if self.rank <= 0:
diff --git a/codes/utils/options.py b/codes/utils/options.py
index 36c2dec3..b6abdf2e 100644
--- a/codes/utils/options.py
+++ b/codes/utils/options.py
@@ -39,9 +39,9 @@ def parse(opt_path, is_train=True):
             opt['path'][key] = osp.expanduser(path)
     else:
         opt['path'] = {}
-    opt['path']['root'] = osp.abspath(os.getcwd()) #osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
+    opt['path']['root'] = "./"
     if is_train:
-        experiments_root = osp.join(opt['path']['root'], 'training', opt['name'])
+        experiments_root = osp.join(opt['path']['root'], 'training', opt['name'], "finetune")
         opt['path']['experiments_root'] = experiments_root
         opt['path']['models'] = osp.join(experiments_root, 'models')
         opt['path']['training_state'] = osp.join(experiments_root, 'training_state')