forked from mrq/DL-Art-School
oops
This commit is contained in:
parent
8db762fa17
commit
6c284ef8ec
|
@ -112,8 +112,9 @@ class Trainer:
|
||||||
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
|
self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
|
||||||
# it will indefinitely try to train if your batch size is larger than your dataset
|
# it will indefinitely try to train if your batch size is larger than your dataset
|
||||||
# could just whine when generating the YAML rather than assert here
|
# could just whine when generating the YAML rather than assert here
|
||||||
if len(self.train_set) <= dataset_opt['batch_size']:
|
if len(self.train_set) < dataset_opt['batch_size']:
|
||||||
raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
|
dataset_opt['batch_size'] = len(self.train_set)
|
||||||
|
print("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.")
|
||||||
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
||||||
total_iters = int(opt['train']['niter'])
|
total_iters = int(opt['train']['niter'])
|
||||||
self.total_epochs = int(math.ceil(total_iters / train_size))
|
self.total_epochs = int(math.ceil(total_iters / train_size))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user