tweaks for the local engine orchestrator (that I never caught since I always used the deepspeed backend)
This commit is contained in:
parent
9a62e3b824
commit
6b237ae5e3
|
@ -66,10 +66,11 @@ class Engine():
|
|||
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
|
||||
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
|
||||
|
||||
self.global_steps = kwargs.pop("global_steps", 0)
|
||||
self.micro_steps = kwargs.pop("micro_steps", 0)
|
||||
self.global_samples = kwargs.pop("global_samples", 0)
|
||||
self.tokens_processed = kwargs.pop("tokens_processed", 0)
|
||||
stats = kwargs.pop("stats", {})
|
||||
self.global_steps = stats.pop("global_step", 0)
|
||||
self.micro_steps = stats.pop("micro_step", 0)
|
||||
self.global_samples = stats.pop("global_samples", 0)
|
||||
self.tokens_processed = stats.pop("tokens_processed", 0)
|
||||
|
||||
self._frozen_params = set()
|
||||
|
||||
|
@ -203,7 +204,7 @@ class Engine():
|
|||
if load_lr_scheduler_states:
|
||||
self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
|
||||
|
||||
if 'lora' in state:
|
||||
if 'lora' in state and state['lora'] is not None:
|
||||
lora_load_state_dict( self.module, state['lora'] )
|
||||
|
||||
def eval(self):
|
||||
|
@ -534,7 +535,11 @@ class Engines(dict[str, Engine]):
|
|||
raise RuntimeError("Out of memory during forward pass!")
|
||||
"""
|
||||
|
||||
# no results are returned when a nan is encountered, so catch it here too
|
||||
if res is None:
|
||||
self.max_nan_losses = self.max_nan_losses - 1
|
||||
if self.max_nan_losses < 0:
|
||||
raise RuntimeError("Too many NaN losses detected.")
|
||||
continue
|
||||
|
||||
loss, engine_stats = res
|
||||
|
@ -543,7 +548,7 @@ class Engines(dict[str, Engine]):
|
|||
if not cfg.trainer.check_for_oom:
|
||||
engine.backward(loss)
|
||||
else:
|
||||
# to-do: properly handle when one GPU throws an OOM because it just halts
|
||||
# to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce
|
||||
try:
|
||||
engine.backward(loss)
|
||||
except RuntimeError as e:
|
||||
|
|
Loading…
Reference in New Issue
Block a user