diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index f404254..4caf299 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -132,8 +132,10 @@ def load_engines(training=True, **model_kwargs):
 			params['d_coef'] = params['lr']
 			params['lr'] = 1.0
 		elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
+			"""
 			if backend == "deepspeed":
 				raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
+			"""
 
 			optimizer_class = ml.Apollo
 			is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 61fb4e1..053eb0f 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -67,11 +67,12 @@ class Engine():
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
 		stats = kwargs.pop("stats", {})
-		if stats is not None:
-			self.global_steps = stats.pop("global_step", 0)
-			self.micro_steps = stats.pop("micro_step", 0)
-			self.global_samples = stats.pop("global_samples", 0)
-			self.tokens_processed = stats.pop("tokens_processed", 0)
+		if stats is None:
+			stats = {}
+		self.global_steps = stats.pop("global_step", 0)
+		self.micro_steps = stats.pop("micro_step", 0)
+		self.global_samples = stats.pop("global_samples", 0)
+		self.tokens_processed = stats.pop("tokens_processed", 0)
 
 		self._frozen_params = set()
 
@@ -186,7 +187,7 @@ class Engine():
 
 		if not load_path.exists():
 			return
-		
+
 		state = torch_load(load_path, device=cfg.device)
 
 		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
@@ -542,8 +543,8 @@ class Engines(dict[str, Engine]):
 
 			# no results are returned when a nan is encountered, so catch it here too
 			if res is None:
-				self.max_nan_losses = self.max_nan_losses - 1
-				if self.max_nan_losses < 0:
+				engine.max_nan_losses = engine.max_nan_losses - 1
+				if engine.max_nan_losses < 0:
 					raise RuntimeError("Too many NaN losses detected.")
 				continue
 
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index feb52d4..1d0fa13 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -114,6 +114,10 @@ def _make_infinite_epochs(dl):
 			_logger.info("New epoch starts.")
 
 		with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
+			if start:
+				pbar.update(start)
+				start = 0
+			"""
 			if start:
 				pbar.n = start
 				start = 0
@@ -121,6 +125,7 @@ def _make_infinite_epochs(dl):
 
 			# for some reason this is required
 			if manual_update:
 				pbar.n += 1
+			"""
 			yield from pbar