tweaks
This commit is contained in:
parent
0fbfb8bbe8
commit
64c67160a3
|
@ -132,8 +132,10 @@ def load_engines(training=True, **model_kwargs):
|
|||
params['d_coef'] = params['lr']
|
||||
params['lr'] = 1.0
|
||||
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
|
||||
"""
|
||||
if backend == "deepspeed":
|
||||
raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
|
||||
"""
|
||||
|
||||
optimizer_class = ml.Apollo
|
||||
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
|
||||
|
|
|
@ -67,11 +67,12 @@ class Engine():
|
|||
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
|
||||
|
||||
stats = kwargs.pop("stats", {})
|
||||
if stats is not None:
|
||||
self.global_steps = stats.pop("global_step", 0)
|
||||
self.micro_steps = stats.pop("micro_step", 0)
|
||||
self.global_samples = stats.pop("global_samples", 0)
|
||||
self.tokens_processed = stats.pop("tokens_processed", 0)
|
||||
if stats is None:
|
||||
stats = {}
|
||||
self.global_steps = stats.pop("global_step", 0)
|
||||
self.micro_steps = stats.pop("micro_step", 0)
|
||||
self.global_samples = stats.pop("global_samples", 0)
|
||||
self.tokens_processed = stats.pop("tokens_processed", 0)
|
||||
|
||||
self._frozen_params = set()
|
||||
|
||||
|
@ -186,7 +187,7 @@ class Engine():
|
|||
|
||||
if not load_path.exists():
|
||||
return
|
||||
|
||||
|
||||
state = torch_load(load_path, device=cfg.device)
|
||||
|
||||
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
|
||||
|
@ -542,8 +543,8 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
# no results are returned when a NaN is encountered, so catch it here too
|
||||
if res is None:
|
||||
self.max_nan_losses = self.max_nan_losses - 1
|
||||
if self.max_nan_losses < 0:
|
||||
engine.max_nan_losses = engine.max_nan_losses - 1
|
||||
if engine.max_nan_losses < 0:
|
||||
raise RuntimeError("Too many NaN losses detected.")
|
||||
continue
|
||||
|
||||
|
|
|
@ -114,6 +114,10 @@ def _make_infinite_epochs(dl):
|
|||
_logger.info("New epoch starts.")
|
||||
|
||||
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
|
||||
if start:
|
||||
pbar.update(start)
|
||||
start = 0
|
||||
"""
|
||||
if start:
|
||||
pbar.n = start
|
||||
start = 0
|
||||
|
@ -121,6 +125,7 @@ def _make_infinite_epochs(dl):
|
|||
# for some reason this is required
|
||||
if manual_update:
|
||||
pbar.n += 1
|
||||
"""
|
||||
yield from pbar
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user