From 6b237ae5e3c5100c75b0e3a044ef861e156eb4f4 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 12 Dec 2024 13:37:38 -0600
Subject: [PATCH] tweaks for the local engine orchestrator (that I never
 caught since I always used the deepspeed backend)

---
 vall_e/engines/base.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 270b477..51d91f0 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -66,10 +66,11 @@ class Engine():
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
-		self.global_steps = kwargs.pop("global_steps", 0)
-		self.micro_steps = kwargs.pop("micro_steps", 0)
-		self.global_samples = kwargs.pop("global_samples", 0)
-		self.tokens_processed = kwargs.pop("tokens_processed", 0)
+		stats = kwargs.pop("stats", {})
+		self.global_steps = stats.pop("global_step", 0)
+		self.micro_steps = stats.pop("micro_step", 0)
+		self.global_samples = stats.pop("global_samples", 0)
+		self.tokens_processed = stats.pop("tokens_processed", 0)
 
 		self._frozen_params = set()
 
@@ -203,7 +204,7 @@ class Engine():
 			if load_lr_scheduler_states:
 				self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
 
-			if 'lora' in state:
+			if 'lora' in state and state['lora'] is not None:
 				lora_load_state_dict( self.module, state['lora'] )
 
 	def eval(self):
@@ -534,7 +535,11 @@ class Engines(dict[str, Engine]):
 					raise RuntimeError("Out of memory during forward pass!")
 			"""
 
+			# no results are returned when a nan is encountered, so catch it here too
 			if res is None:
+				self.max_nan_losses = self.max_nan_losses - 1
+				if self.max_nan_losses < 0:
+					raise RuntimeError("Too many NaN losses detected.")
 				continue
 
 			loss, engine_stats = res
@@ -543,7 +548,7 @@ class Engines(dict[str, Engine]):
 			if not cfg.trainer.check_for_oom:
 				engine.backward(loss)
 			else:
-				# to-do: properly handle when one GPU throws an OOM because it just halts
+				# to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce
 				try:
 					engine.backward(loss)
 				except RuntimeError as e:
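
Note on the first hunk: the local engine used to read its counters straight off the constructor kwargs, while the trainer nests them under a "stats" dict with singular key names ("global_step", "micro_step"), so the old plural kwarg names presumably never matched what was actually passed in. A minimal sketch of the calling convention after this patch; only the shape of the "stats" kwarg is taken from the patch itself, while the toy module, optimizer, and the model= kwarg are assumptions for illustration.

    # Standalone sketch of the new calling convention; the toy module,
    # optimizer, and the model= kwarg are placeholders, not vall_e's
    # actual wiring. Only the "stats" shape comes from the patch.
    import torch
    from vall_e.engines.base import Engine

    net = torch.nn.Linear(8, 8)
    engine = Engine(
        model=net,  # assumption: how the module is passed
        optimizer=torch.optim.AdamW(net.parameters(), lr=1.0e-4),
        lr_scheduler=None,
        stats={
            "global_step": 1000,        # note the singular key names,
            "micro_step": 4000,         # matching what checkpoints store,
            "global_samples": 256_000,  # not the old plural kwarg names
            "tokens_processed": 12_000_000,
        },
    )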
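Note on the second hunk: key membership alone isn't a safe guard, since a checkpoint can carry a "lora" key whose value is None (e.g. one saved from a run with no LoRA attached), and handing None to the loader would fail. A minimal illustration with a stubbed loader; the checkpoint dict and the stub are placeholders, not vall_e's actual objects.

    # Minimal illustration of the extra None check; the checkpoint dict
    # and loader stub are placeholders, not vall_e's actual objects.
    def lora_load_state_dict(module, state_dict):  # stub for the real loader
        assert state_dict is not None, "the real loader would choke on None"

    state = {"module": {}, "lora": None}  # checkpoint saved without a LoRA

    # old guard: `'lora' in state` is True, so the loader would receive None;
    # the new guard skips the call entirely for a None payload:
    if "lora" in state and state["lora"] is not None:
        lora_load_state_dict(object(), state["lora"])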
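Note on the third hunk: a step that hits a NaN loss returns no result, and previously the orchestrator would just skip it and keep looping indefinitely. The patch gives it a budget: each None result decrements max_nan_losses, and training aborts once the budget goes negative. A distilled, runnable sketch of the pattern; the step function, budget value, and loop are illustrative, not the actual Engines code.

    # Distilled sketch of the NaN-loss budget; every name here is
    # illustrative except the guard logic itself.
    import math
    import random

    max_nan_losses = 8  # assumed starting budget; the real value comes from config

    def train_step():
        # Stand-in for a step whose loss occasionally goes NaN; like the
        # real code, it returns None instead of propagating the NaN.
        loss = random.choice([0.5, float("nan")])
        return None if math.isnan(loss) else loss

    for _ in range(100):
        res = train_step()
        # no result is returned when a NaN is encountered, so catch it here
        if res is None:
            max_nan_losses -= 1
            if max_nan_losses < 0:
                raise RuntimeError("Too many NaN losses detected.")
            continue
        loss = res
        # ... backward pass / optimizer step would follow here ...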