remotes/1719357675652877588/tmp_refs/heads/master
mrq 2023-03-09 00:29:25 +07:00
parent 6eb7ebf847
commit 0ee0f46596
2 changed files with 5 additions and 7 deletions

@ -111,11 +111,6 @@ class Trainer:
self.dataset_debugger = get_dataset_debugger(dataset_opt)
if self.dataset_debugger is not None and resume_state is not None:
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
# it will indefinitely try to train if your batch size is larger than your dataset
# could just whine when generating the YAML rather than assert here
if len(self.train_set) < dataset_opt['batch_size']:
dataset_opt['batch_size'] = len(self.train_set)
print("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
self.total_epochs = int(math.ceil(total_iters / train_size))
@ -133,6 +128,9 @@ class Trainer:
self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
self.total_epochs, total_iters))
elif phase == 'val':
if not opt_get(opt, ['eval', 'pure'], False):
continue
self.val_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
if self.rank <= 0:

@ -39,9 +39,9 @@ def parse(opt_path, is_train=True):
opt['path'][key] = osp.expanduser(path)
else:
opt['path'] = {}
opt['path']['root'] = osp.abspath(os.getcwd()) #osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
opt['path']['root'] = "./"
if is_train:
experiments_root = osp.join(opt['path']['root'], 'training', opt['name'])
experiments_root = osp.join(opt['path']['root'], 'training', opt['name'], "finetune")
opt['path']['experiments_root'] = experiments_root
opt['path']['models'] = osp.join(experiments_root, 'models')
opt['path']['training_state'] = osp.join(experiments_root, 'training_state')