tweaks
This commit is contained in:
parent
0fbfb8bbe8
commit
64c67160a3
@@ -132,8 +132,10 @@ def load_engines(training=True, **model_kwargs):
         params['d_coef'] = params['lr']
         params['lr'] = 1.0
     elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
+        """
         if backend == "deepspeed":
             raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
+        """
 
         optimizer_class = ml.Apollo
         is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
 
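Note: the only change here is block-commenting the DeepSpeed guard, so the APOLLO optimizers can now be selected under the deepspeed backend as well. For the unchanged `d_coef` lines above: that shuffle appears to follow the convention of adaptive-step optimizers in the Prodigy family, where the configured rate becomes a coefficient on the internally estimated step size and `lr` itself is pinned to 1.0. A minimal sketch of that hand-off (the `params` dict shape is assumed from the diff context):

    # Hypothetical params dict, mirroring the shape implied by the diff.
    params = {"lr": 1.0e-3}

    # Prodigy-style optimizers estimate their own step size; the user-facing
    # rate acts as a multiplier (d_coef) on that estimate, so the raw lr slot
    # is pinned to 1.0.
    params["d_coef"] = params["lr"]
    params["lr"] = 1.0

    print(params)  # {'lr': 1.0, 'd_coef': 0.001}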
@@ -67,11 +67,12 @@ class Engine():
         self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
         stats = kwargs.pop("stats", {})
-        if stats is not None:
-            self.global_steps = stats.pop("global_step", 0)
-            self.micro_steps = stats.pop("micro_step", 0)
-            self.global_samples = stats.pop("global_samples", 0)
-            self.tokens_processed = stats.pop("tokens_processed", 0)
+        if stats is None:
+            stats = {}
+        self.global_steps = stats.pop("global_step", 0)
+        self.micro_steps = stats.pop("micro_step", 0)
+        self.global_samples = stats.pop("global_samples", 0)
+        self.tokens_processed = stats.pop("tokens_processed", 0)
 
         self._frozen_params = set()
 
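Note: `kwargs.pop("stats", {})` only falls back to `{}` when the key is absent entirely, so a caller passing `stats=None` explicitly made the old guard skip the whole block and left `global_steps` and the other counters undefined. The fix normalizes `None` to `{}` and lets the `pop` defaults apply unconditionally. A minimal, self-contained repro of the difference:

    # Minimal repro of the bug this hunk fixes, outside the Engine class.
    def old_init(**kwargs):
        stats = kwargs.pop("stats", {})   # default only covers a *missing* key
        if stats is not None:
            return stats.pop("global_step", 0)
        # explicit stats=None fell through here, leaving the counters unset

    def new_init(**kwargs):
        stats = kwargs.pop("stats", {})
        if stats is None:                 # normalize the explicit-None case
            stats = {}
        return stats.pop("global_step", 0)

    print(old_init(stats=None))  # None -- the counter assignments were skipped
    print(new_init(stats=None))  # 0    -- defaults always apply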
@@ -186,7 +187,7 @@ class Engine():
 
         if not load_path.exists():
             return
 
         state = torch_load(load_path, device=cfg.device)
 
         self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
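Note: the `'stats' in state` fallback on the last line keeps older, flat checkpoints loadable alongside newer ones that nest the counters under a `stats` key. A hypothetical helper showing the same compatibility pattern:

    # Hypothetical helper mirroring the old/new checkpoint compatibility above:
    # newer checkpoints nest counters under 'stats', older ones store them flat.
    def read_counters(state: dict) -> dict:
        stats = state.get('stats', state)
        return {
            "global_step": stats.get("global_step", 0),
            "micro_step": stats.get("micro_step", 0),
        }

    print(read_counters({"global_step": 5}))             # old, flat layout
    print(read_counters({"stats": {"global_step": 5}}))  # new, nested layout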
@@ -542,8 +543,8 @@ class Engines(dict[str, Engine]):
 
             # no results are returned when a nan is encountered, so catch it here too
             if res is None:
-                self.max_nan_losses = self.max_nan_losses - 1
-                if self.max_nan_losses < 0:
+                engine.max_nan_losses = engine.max_nan_losses - 1
+                if engine.max_nan_losses < 0:
                     raise RuntimeError("Too many NaN losses detected.")
                 continue
 
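Note: `Engines` subclasses `dict[str, Engine]`, so the old `self.max_nan_losses` either shared one NaN budget across every engine or raised `AttributeError` if the container never defined the attribute; the fix charges each NaN against the engine that produced it. A self-contained sketch of the corrected accounting (`FakeEngine` and its `step` method are stand-ins):

    class FakeEngine:
        """Stand-in for Engine; assumes a per-engine max_nan_losses budget."""
        def __init__(self, budget=2):
            self.max_nan_losses = budget
        def step(self, batch):
            return None  # simulate a NaN loss: no results returned

    engines = {"a": FakeEngine(), "b": FakeEngine()}
    batch = {}

    for name, engine in engines.items():
        res = engine.step(batch)
        if res is None:
            # charge the NaN against the engine that produced it -- the old
            # code decremented self.max_nan_losses on the container instead
            engine.max_nan_losses -= 1
            if engine.max_nan_losses < 0:
                raise RuntimeError("Too many NaN losses detected.")
            continue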
@@ -114,6 +114,10 @@ def _make_infinite_epochs(dl):
         _logger.info("New epoch starts.")
 
         with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
+            if start:
+                pbar.update(start)
+                start = 0
+            """
             if start:
                 pbar.n = start
                 start = 0
@@ -121,6 +125,7 @@ def _make_infinite_epochs(dl):
             # for some reason this is required
             if manual_update:
                 pbar.n += 1
+            """
             yield from pbar
 
 
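Note: these two hunks swap the fast-forward from assigning `pbar.n` to calling `pbar.update(start)`, commenting out the old path. The distinction matters because assigning `.n` directly only moves the counter; the bar is not redrawn until the next `refresh()`, which is likely why the old path needed the `pbar.n += 1` workaround. A small standalone illustration:

    from tqdm import tqdm
    import time

    start = 40  # e.g. resuming mid-epoch

    with tqdm(total=100, desc="Epoch progress", dynamic_ncols=True) as pbar:
        # pbar.n = start      # fast-forwards the counter, but nothing is redrawn
        # pbar.refresh()      # ...until an explicit refresh
        pbar.update(start)    # advances the counter and redraws in one call
        for _ in range(start, 100):
            time.sleep(0.01)  # stand-in for real work
            pbar.update(1)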