This commit is contained in:
mrq 2024-12-13 19:00:35 -06:00
parent 0fbfb8bbe8
commit 64c67160a3
3 changed files with 16 additions and 8 deletions

View File

@ -132,8 +132,10 @@ def load_engines(training=True, **model_kwargs):
params['d_coef'] = params['lr']
params['lr'] = 1.0
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
"""
if backend == "deepspeed":
raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
"""
optimizer_class = ml.Apollo
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"

View File

@ -67,11 +67,12 @@ class Engine():
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
stats = kwargs.pop("stats", {})
if stats is not None:
self.global_steps = stats.pop("global_step", 0)
self.micro_steps = stats.pop("micro_step", 0)
self.global_samples = stats.pop("global_samples", 0)
self.tokens_processed = stats.pop("tokens_processed", 0)
if stats is None:
stats = {}
self.global_steps = stats.pop("global_step", 0)
self.micro_steps = stats.pop("micro_step", 0)
self.global_samples = stats.pop("global_samples", 0)
self.tokens_processed = stats.pop("tokens_processed", 0)
self._frozen_params = set()
@ -186,7 +187,7 @@ class Engine():
if not load_path.exists():
return
state = torch_load(load_path, device=cfg.device)
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
@ -542,8 +543,8 @@ class Engines(dict[str, Engine]):
# no results are returned when a nan is encountered, so catch it here too
if res is None:
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
engine.max_nan_losses = engine.max_nan_losses - 1
if engine.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
continue

View File

@ -114,6 +114,10 @@ def _make_infinite_epochs(dl):
_logger.info("New epoch starts.")
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
if start:
pbar.update(start)
start = 0
"""
if start:
pbar.n = start
start = 0
@ -121,6 +125,7 @@ def _make_infinite_epochs(dl):
# for some reason this is required
if manual_update:
pbar.n += 1
"""
yield from pbar