finally figured out a clean way to handle "resuming" the tqdm bar

mrq 2024-12-14 18:44:43 -06:00
parent 35389481ee
commit 3dd31e74d1
3 changed files with 8 additions and 23 deletions


@@ -140,7 +140,7 @@ def load_engines(training=True, **model_kwargs):
 				"proj": "random",
 				"scale_type": "tensor" if is_mini else "channel",
 				"scale": 128 if is_mini else 1,
-				"update_proj_gap": 200,
+				"update_proj_gap": 1,
 				"proj_type": "std",
 			})
 		elif cfg.hyperparameters.optimizer.lower() == "adagrad":


@@ -277,7 +277,7 @@ class Apollo(Optimizer):
 		proj: str = "random",
 		scale_type: str = "channel",
 		scale: int = 1,
-		update_proj_gap: int = 200,
+		update_proj_gap: int = 1,
 		proj_type: str = "std",
 	):
 		if lr < 0.0:
@@ -288,6 +288,7 @@ class Apollo(Optimizer):
 			raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
 		if not 0.0 <= eps:
 			raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
 		defaults = {
 			"lr": lr,
 			"betas": betas,

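Both hunks above drop update_proj_gap from 200 to 1; in APOLLO-style optimizers this value is the number of steps between refreshes of the projection used to approximate gradient scaling, so a gap of 1 refreshes it every step. Below is a minimal sketch of how the resulting kwargs might be assembled, using only the parameters visible in the diff; is_mini and the commented-out constructor call are illustrative assumptions, not code from this repo.

	# hedged sketch: APOLLO optimizer kwargs with the new default
	is_mini = False  # assumption for illustration

	apollo_kwargs = {
		"proj": "random",
		"scale_type": "tensor" if is_mini else "channel",
		"scale": 128 if is_mini else 1,
		"update_proj_gap": 1,  # refresh the projection every optimizer step
		"proj_type": "std",
	}
	# optimizer = Apollo(model.parameters(), lr=1e-4, **apollo_kwargs)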

@@ -102,33 +102,17 @@ def _non_blocking_input():
 def _make_infinite_epochs(dl):
-	start = dl.dataset.index()
-	total = dl.dataset.batches()
-	manual_update = False
-
-	if total == 0:
-		raise Exception("Empty dataset")
-
 	while True:
 		if dl.dataset.index() == 0:
 			_logger.info("New epoch starts.")
 
+		total = dl.dataset.batches() - dl.dataset.index()
+		if total <= 0:
+			raise Exception("Empty dataset")
+
-		with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
-			if start:
-				pbar.update(start)
-				start = 0
-			"""
-			if start:
-				pbar.n = start
-				start = 0
-				manual_update = True
-
-			# for some reason this is required
-			if manual_update:
-				pbar.n += 1
-			"""
+		with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), total=total) as pbar:
 			yield from pbar
 
 @local_leader_only(default=None)
 def logger(data):
 	return _logger.info(json.dumps(data, default=str))
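The "resume" trick in the last hunk boils down to sizing the bar to the batches that are actually left, instead of building a full-length bar and fast-forwarding it with pbar.update(start). Here is a minimal, self-contained sketch of the same pattern; make_epoch_iter, batches_total, and batches_done are stand-ins for the repo's dataloader and dataset accessors, not its API.

	from itertools import islice
	from tqdm import tqdm

	def make_infinite_epochs(make_epoch_iter, batches_total, batches_done=0):
		# Yield batches forever; when resuming mid-epoch, create the bar with
		# total=remaining so it reads 0/remaining rather than being fast-forwarded.
		while True:
			remaining = batches_total - batches_done
			if remaining <= 0:
				raise Exception("Empty dataset")
			epoch = islice(make_epoch_iter(), batches_done, None)  # skip already-consumed batches
			with tqdm(epoch, "Epoch progress", dynamic_ncols=True, total=remaining) as pbar:
				yield from pbar
			batches_done = 0  # every later epoch starts from the beginning

	# e.g. resume at batch 40 of a 100-batch epoch:
	# for batch in make_infinite_epochs(lambda: range(100), 100, batches_done=40): ...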