.
This commit is contained in:
parent
6eb7ebf847
commit
0ee0f46596
|
@ -111,11 +111,6 @@ class Trainer:
|
|||
self.dataset_debugger = get_dataset_debugger(dataset_opt)
|
||||
if self.dataset_debugger is not None and resume_state is not None:
|
||||
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
|
||||
# it will indefinitely try to train if your batch size is larger than your dataset
|
||||
# could just whine when generating the YAML rather than assert here
|
||||
if len(self.train_set) < dataset_opt['batch_size']:
|
||||
dataset_opt['batch_size'] = len(self.train_set)
|
||||
print("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
|
||||
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
||||
total_iters = int(opt['train']['niter'])
|
||||
self.total_epochs = int(math.ceil(total_iters / train_size))
|
||||
|
@ -133,6 +128,9 @@ class Trainer:
|
|||
self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
||||
self.total_epochs, total_iters))
|
||||
elif phase == 'val':
|
||||
if not opt_get(opt, ['eval', 'pure'], False):
|
||||
continue
|
||||
|
||||
self.val_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
||||
self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
|
||||
if self.rank <= 0:
|
||||
|
|
|
@ -39,9 +39,9 @@ def parse(opt_path, is_train=True):
|
|||
opt['path'][key] = osp.expanduser(path)
|
||||
else:
|
||||
opt['path'] = {}
|
||||
opt['path']['root'] = osp.abspath(os.getcwd()) #osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
|
||||
opt['path']['root'] = "./"
|
||||
if is_train:
|
||||
experiments_root = osp.join(opt['path']['root'], 'training', opt['name'])
|
||||
experiments_root = osp.join(opt['path']['root'], 'training', opt['name'], "finetune")
|
||||
opt['path']['experiments_root'] = experiments_root
|
||||
opt['path']['models'] = osp.join(experiments_root, 'models')
|
||||
opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
|
||||
|
|
Loading…
Reference in New Issue
Block a user