@@ -111,11 +111,6 @@ class Trainer:
                 self.dataset_debugger = get_dataset_debugger(dataset_opt)
                 if self.dataset_debugger is not None and resume_state is not None:
                     self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
-                # it will indefinitely try to train if your batch size is larger than your dataset
-                # could just whine when generating the YAML rather than assert here
-                if len(self.train_set) < dataset_opt['batch_size']:
-                    dataset_opt['batch_size'] = len(self.train_set)
-                    print("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
                 train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                 total_iters = int(opt['train']['niter'])
                 self.total_epochs = int(math.ceil(total_iters / train_size))
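For context, the guard removed above worked around a real failure mode: a PyTorch DataLoader built with drop_last=True yields zero batches when the dataset is smaller than the batch size, so the step counter never advances and training never reaches niter. A minimal sketch of that mode and of the epoch arithmetic in the hunk, assuming the training loader drops incomplete batches (the actual construction lives in create_dataloader):

    # Sketch of the failure behind "it will indefinitely try to train".
    # drop_last=True on the training loader is an assumption here, implied
    # by the removed comment rather than shown in this diff.
    import math
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    train_set = TensorDataset(torch.zeros(3, 1))           # 3 samples
    loader = DataLoader(train_set, batch_size=8, drop_last=True)
    print(len(loader))                                     # 0 batches per epoch

    # The epoch arithmetic from the hunk above, worked through:
    train_size = int(math.ceil(len(train_set) / 8))        # ceil(3/8) = 1 step/epoch
    total_iters = 500
    total_epochs = int(math.ceil(total_iters / train_size))  # 500 epochs planned
    # With 0 real batches per epoch, those 500 epochs never accumulate a single iteration.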
@@ -133,6 +128,9 @@ class Trainer:
                 self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                     self.total_epochs, total_iters))
             elif phase == 'val':
+                if not opt_get(opt, ['eval', 'pure'], False):
+                    continue
+
                 self.val_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
                 self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
                 if self.rank <= 0:
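The gate added above means the val set and loader are only built when the config opts in to pure evaluation. opt_get(opt, ['eval', 'pure'], False) reads a nested config key with a fallback; a minimal sketch of that lookup plus a hypothetical config fragment (the real opt_get lives in the repo's utils and may differ in detail):

    # Assumed semantics of opt_get: walk nested dict keys, return default on a miss.
    def opt_get(opt, keys, default=None):
        for k in keys:
            if not isinstance(opt, dict) or k not in opt:
                return default
            opt = opt[k]
        return opt

    opt = {'eval': {'pure': True}}                        # hypothetical parsed YAML fragment
    assert opt_get(opt, ['eval', 'pure'], False) is True      # val loader gets built
    assert opt_get({}, ['eval', 'pure'], False) is False      # gate hits `continue`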