diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py index d04badc..b20cb40 100755 --- a/vall_e/engines/deepspeed.py +++ b/vall_e/engines/deepspeed.py @@ -39,16 +39,17 @@ class Engine(DeepSpeedEngine): kwargs['config'] = cfg.trainer.deepspeed.ds_cfg kwargs['config_class'] = DeepSpeedConfig(kwargs['config']) - if "stats" in kwargs: - # stats COULD be = None - stats = kwargs.pop('stats') - if stats is None: - stats = { - "global_steps": 0, - "micro_steps": 0, - "global_samples": 0, - "tokens_processed": 0, - } + stats = { + "global_steps": 0, + "micro_steps": 0, + "global_samples": 0, + "tokens_processed": 0, + } + + # kwargs['stats'] = None will return None when popped + maybe_stats = kwargs.pop('stats', stats) + if maybe_stats is not None: + stats = maybe_stats super().__init__(None, *args, **kwargs) self._frozen_params = set()