diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 7545b3c..3a5c7d6 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -222,6 +222,7 @@ class Engines(dict[str, Engine]):
 		self._global_step = 0
 		self._micro_step = 0
 		self._batch_size = 0
+		self._global_samples = 0
 
 	@property
 	def global_step(self):
@@ -235,6 +236,10 @@ class Engines(dict[str, Engine]):
 	def batch_size(self):
 		return self._batch_size
 
+	@property
+	def global_samples(self):
+		return self._global_samples
+
 	def gather_attribute(self, *args, **kwargs):
 		ret = {}
 		for engine in self.values():
@@ -324,6 +329,7 @@ class Engines(dict[str, Engine]):
 			self._global_step = max(self._global_step, engine.global_step)
 			self._micro_step = max(self._micro_step, engine.micro_step)
 			self._batch_size = max(self._batch_size, engine.batch_size)
+			self._global_samples = max(self._global_samples, engine.global_samples)
 
 	def eval(self):
 		for engine in self.values():
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 15d3828..cfb4c7c 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -251,9 +251,8 @@ def train(
 		#batch = to_device(batch, torch.cuda.current_device())
 		stats = engines.step(batch=batch, feeder=train_feeder)
 
-		iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
-		stats['it'] = iteration
-		stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+		stats['it'] = stats['global_step']
+		stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
 
 		stats['batch'] = {
 			'size': len(batch['text']),