From 6c284ef8ec4c4769de3181d90ac96ff63581ef55 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 18 Feb 2023 03:27:04 +0000 Subject: [PATCH] oops --- codes/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codes/train.py b/codes/train.py index cf63dda2..88d1caac 100644 --- a/codes/train.py +++ b/codes/train.py @@ -112,8 +112,9 @@ class Trainer: self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {})) # it will indefinitely try to train if your batch size is larger than your dataset # could just whine when generating the YAML rather than assert here - if len(self.train_set) <= dataset_opt['batch_size']: - raise Exception("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.") + if len(self.train_set) < dataset_opt['batch_size']: + dataset_opt['batch_size'] = len(self.train_set) + print("dataset size is less than batch size, consider reducing your batch size, or increasing your dataset.") train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) self.total_epochs = int(math.ceil(total_iters / train_size))