This commit is contained in:
mrq 2024-12-13 19:00:35 -06:00
parent 0fbfb8bbe8
commit 64c67160a3
3 changed files with 16 additions and 8 deletions

View File

@ -132,8 +132,10 @@ def load_engines(training=True, **model_kwargs):
params['d_coef'] = params['lr'] params['d_coef'] = params['lr']
params['lr'] = 1.0 params['lr'] = 1.0
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]: elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
"""
if backend == "deepspeed": if backend == "deepspeed":
raise Exception("APOLLO currently does not play nicely with DeepSpeed.") raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
"""
optimizer_class = ml.Apollo optimizer_class = ml.Apollo
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini" is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"

View File

@ -67,7 +67,8 @@ class Engine():
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
stats = kwargs.pop("stats", {}) stats = kwargs.pop("stats", {})
if stats is not None: if stats is None:
stats = {}
self.global_steps = stats.pop("global_step", 0) self.global_steps = stats.pop("global_step", 0)
self.micro_steps = stats.pop("micro_step", 0) self.micro_steps = stats.pop("micro_step", 0)
self.global_samples = stats.pop("global_samples", 0) self.global_samples = stats.pop("global_samples", 0)
@ -542,8 +543,8 @@ class Engines(dict[str, Engine]):
# no results are returned when a nan is encountered, so catch it here too # no results are returned when a nan is encountered, so catch it here too
if res is None: if res is None:
self.max_nan_losses = self.max_nan_losses - 1 engine.max_nan_losses = engine.max_nan_losses - 1
if self.max_nan_losses < 0: if engine.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.") raise RuntimeError("Too many NaN losses detected.")
continue continue

View File

@ -114,6 +114,10 @@ def _make_infinite_epochs(dl):
_logger.info("New epoch starts.") _logger.info("New epoch starts.")
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar: with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
if start:
pbar.update(start)
start = 0
"""
if start: if start:
pbar.n = start pbar.n = start
start = 0 start = 0
@ -121,6 +125,7 @@ def _make_infinite_epochs(dl):
# for some reason this is required # for some reason this is required
if manual_update: if manual_update:
pbar.n += 1 pbar.n += 1
"""
yield from pbar yield from pbar