From 7d1220e83e751cebe644b59018babf49e247c965 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 23 Feb 2023 15:38:04 +0000 Subject: [PATCH] forgot to mult by batch size --- src/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils.py b/src/utils.py index 6089f15..3f7cac9 100755 --- a/src/utils.py +++ b/src/utils.py @@ -445,6 +445,7 @@ class TrainingState(): with open(config_path, 'r') as file: self.config = yaml.safe_load(file) + self.batch_size = self.config['datasets']['train']['batch_size'] self.dataset_path = self.config['datasets']['train']['path'] with open(self.dataset_path, 'r', encoding="utf-8") as f: self.dataset_size = len(f.readlines()) @@ -453,7 +454,7 @@ class TrainingState(): self.its = self.config['train']['niter'] self.epoch = 0 - self.epochs = int(self.its/self.dataset_size) + self.epochs = int(self.its*self.batch_size/self.dataset_size) self.checkpoint = 0 self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])