|
|
@ -445,6 +445,7 @@ class TrainingState():
|
|
|
|
with open(config_path, 'r') as file:
|
|
|
|
with open(config_path, 'r') as file:
|
|
|
|
self.config = yaml.safe_load(file)
|
|
|
|
self.config = yaml.safe_load(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.batch_size = self.config['datasets']['train']['batch_size']
|
|
|
|
self.dataset_path = self.config['datasets']['train']['path']
|
|
|
|
self.dataset_path = self.config['datasets']['train']['path']
|
|
|
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
|
|
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
|
|
|
self.dataset_size = len(f.readlines())
|
|
|
|
self.dataset_size = len(f.readlines())
|
|
|
@ -453,7 +454,7 @@ class TrainingState():
|
|
|
|
self.its = self.config['train']['niter']
|
|
|
|
self.its = self.config['train']['niter']
|
|
|
|
|
|
|
|
|
|
|
|
self.epoch = 0
|
|
|
|
self.epoch = 0
|
|
|
|
self.epochs = int(self.its/self.dataset_size)
|
|
|
|
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
|
|
|
|
|
|
|
|
|
|
|
self.checkpoint = 0
|
|
|
|
self.checkpoint = 0
|
|
|
|
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
|
|
|
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
|
|
|