diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index 7b05385..7d63f56 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -446,7 +446,7 @@ class Engines(dict[str, Engine]): engine.tokens_processed = 0 # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler - if cfg.hyperparameters.scheduler_type == "": + if cfg.hyperparameters.scheduler == "": self.set_lr(cfg.hyperparameters.learning_rate) self._update() diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py index f07705b..85ec63b 100755 --- a/vall_e/engines/deepspeed.py +++ b/vall_e/engines/deepspeed.py @@ -47,7 +47,7 @@ class Engine(DeepSpeedEngine): } # kwargs['stats'] = None will return None when popped - maybe_stats = kwargs.get('stats', stats) + maybe_stats = kwargs.pop('stats', stats) if maybe_stats is not None: stats = maybe_stats