tweaks for the local engine orchestrator (fixing issues I never caught since I always used the deepspeed backend)

mrq 2024-12-12 13:37:38 -06:00
parent 9a62e3b824
commit 6b237ae5e3


@@ -66,10 +66,11 @@ class Engine():
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
-		self.global_steps = kwargs.pop("global_steps", 0)
-		self.micro_steps = kwargs.pop("micro_steps", 0)
-		self.global_samples = kwargs.pop("global_samples", 0)
-		self.tokens_processed = kwargs.pop("tokens_processed", 0)
+		stats = kwargs.pop("stats", {})
+		self.global_steps = stats.pop("global_step", 0)
+		self.micro_steps = stats.pop("micro_step", 0)
+		self.global_samples = stats.pop("global_samples", 0)
+		self.tokens_processed = stats.pop("tokens_processed", 0)
 		self._frozen_params = set()
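With this change, the training counters come in as a single `stats` dict instead of four separate kwargs, and the keys inside it are singular (`global_step`, `micro_step`) even though the attributes stay plural. A minimal caller-side sketch, assuming a checkpoint dict that carries a `stats` entry (the checkpoint layout and file name are assumptions for illustration, not the repo's actual format):

import torch

# hypothetical caller; optimizer/lr_scheduler are built elsewhere
state = torch.load("checkpoint.pt", map_location="cpu")
engine = Engine(
	optimizer=optimizer,
	lr_scheduler=lr_scheduler,
	stats=state.get("stats", {}),  # e.g. {"global_step": 100, "micro_step": 400, "global_samples": 3200, "tokens_processed": 819200}
)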
@@ -203,7 +204,7 @@ class Engine():
 		if load_lr_scheduler_states:
 			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
-		if 'lora' in state:
+		if 'lora' in state and state['lora'] is not None:
 			lora_load_state_dict( self.module, state['lora'] )

 	def eval(self):
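The added `is not None` guard matters because a checkpoint can store `lora: None` when it was saved without an active LoRA; the membership test alone still passes and hands `None` to `lora_load_state_dict`. A tiny sketch of the case the new check skips (the checkpoint shape is illustrative):

# illustrative checkpoint saved while no LoRA was attached
state = {"lora": None}
print("lora" in state)                                # True: the old check alone would try to load None
print("lora" in state and state["lora"] is not None)  # False: the new check skips the load cleanly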
@@ -534,7 +535,11 @@ class Engines(dict[str, Engine]):
 				raise RuntimeError("Out of memory during forward pass!")
 			"""

+			# no results are returned when a nan is encountered, so catch it here too
 			if res is None:
+				self.max_nan_losses = self.max_nan_losses - 1
+				if self.max_nan_losses < 0:
+					raise RuntimeError("Too many NaN losses detected.")
 				continue

 			loss, engine_stats = res
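For the local backend, a NaN'd forward pass now burns down a budget instead of being skipped silently forever. Here is the same pattern in isolation, runnable with stand-in names (`feeder` and the starting budget are illustrative; the trainer gets both from elsewhere):

import math

def feeder(batch):  # stand-in for the real per-engine feeder
	loss = batch
	return None if math.isnan(loss) else (loss, {})  # no result when the loss is NaN

max_nan_losses = 2  # assumed starting budget
for batch in [1.0, float("nan"), 2.0, float("nan"), float("nan")]:
	res = feeder(batch)
	if res is None:            # NaN loss: nothing came back
		max_nan_losses -= 1
		if max_nan_losses < 0:
			raise RuntimeError("Too many NaN losses detected.")
		continue               # skip backward/step for this batch
	loss, engine_stats = res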
@@ -543,7 +548,7 @@ class Engines(dict[str, Engine]):
 			if not cfg.trainer.check_for_oom:
 				engine.backward(loss)
 			else:
-				# to-do: properly handle when one GPU throws an OOM because it just halts
+				# to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce
 				try:
 					engine.backward(loss)
 				except RuntimeError as e:
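The updated to-do names the real hazard: in data-parallel training the backward pass participates in a gradient gather/reduce across ranks, so when one GPU raises an OOM and drops out of the collective, the surviving ranks block inside it and the job hangs instead of recovering. One common workaround, sketched purely as an assumption about what a fix could look like (not what this repo does), is to have every rank vote on whether anyone failed before touching gradients again; note it only helps if the failure surfaces before the backward's own collectives are entered:

import torch
import torch.distributed as dist

# hypothetical sketch: rank-wide OOM voting around backward
def backward_with_oom_vote(engine, loss, optimizer):
	oom = torch.zeros(1, device=loss.device)
	try:
		engine.backward(loss)
	except RuntimeError:
		oom.fill_(1)  # this rank failed
	dist.all_reduce(oom, op=dist.ReduceOp.MAX)  # every rank learns if any rank failed
	if oom.item() > 0:
		optimizer.zero_grad(set_to_none=True)  # all ranks drop the step together
		torch.cuda.empty_cache()
		return False
	return True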