accurate epoch metric is now reported (based on samples processed / length of dataset's paths, rather than naive assumptions)

This commit is contained in:
mrq 2023-09-03 08:03:36 -05:00
parent 922404285c
commit 81b05dabb9
2 changed files with 8 additions and 3 deletions

View File

@ -222,6 +222,7 @@ class Engines(dict[str, Engine]):
self._global_step = 0
self._micro_step = 0
self._batch_size = 0
self._global_samples = 0
@property
def global_step(self):
@ -235,6 +236,10 @@ class Engines(dict[str, Engine]):
def batch_size(self):
return self._batch_size
@property
def global_samples(self):
return self._global_samples
def gather_attribute(self, *args, **kwargs):
ret = {}
for engine in self.values():
@ -324,6 +329,7 @@ class Engines(dict[str, Engine]):
self._global_step = max(self._global_step, engine.global_step)
self._micro_step = max(self._micro_step, engine.micro_step)
self._batch_size = max(self._batch_size, engine.batch_size)
self._global_samples = max(self._global_samples, engine.global_samples)
def eval(self):
for engine in self.values():

View File

@ -251,9 +251,8 @@ def train(
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder)
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
stats['it'] = iteration
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
stats['it'] = stats['global_step']
stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
stats['batch'] = {
'size': len(batch['text']),