finally figured out a clean way to handle "resuming" the tqdm bar

This commit is contained in:
mrq 2024-12-14 18:44:43 -06:00
parent 35389481ee
commit 3dd31e74d1
3 changed files with 8 additions and 23 deletions

View File

@ -140,7 +140,7 @@ def load_engines(training=True, **model_kwargs):
"proj": "random", "proj": "random",
"scale_type": "tensor" if is_mini else "channel", "scale_type": "tensor" if is_mini else "channel",
"scale": 128 if is_mini else 1, "scale": 128 if is_mini else 1,
"update_proj_gap": 200, "update_proj_gap": 1,
"proj_type": "std", "proj_type": "std",
}) })
elif cfg.hyperparameters.optimizer.lower() == "adagrad": elif cfg.hyperparameters.optimizer.lower() == "adagrad":

View File

@ -277,7 +277,7 @@ class Apollo(Optimizer):
proj: str = "random", proj: str = "random",
scale_type: str = "channel", scale_type: str = "channel",
scale: int = 1, scale: int = 1,
update_proj_gap: int = 200, update_proj_gap: int = 1,
proj_type: str = "std", proj_type: str = "std",
): ):
if lr < 0.0: if lr < 0.0:
@ -288,6 +288,7 @@ class Apollo(Optimizer):
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = { defaults = {
"lr": lr, "lr": lr,
"betas": betas, "betas": betas,

View File

@ -102,32 +102,16 @@ def _non_blocking_input():
def _make_infinite_epochs(dl): def _make_infinite_epochs(dl):
start = dl.dataset.index()
total = dl.dataset.batches()
manual_update = False
if total == 0:
raise Exception("Empty dataset")
while True: while True:
if dl.dataset.index() == 0: if dl.dataset.index() == 0:
_logger.info("New epoch starts.") _logger.info("New epoch starts.")
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar: total = dl.dataset.batches() - dl.dataset.index()
if start: if total <= 0:
pbar.update(start) raise Exception("Empty dataset")
start = 0
"""
if start:
pbar.n = start
start = 0
manual_update = True
# for some reason this is required
if manual_update:
pbar.n += 1
"""
yield from pbar
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), total=total) as pbar:
yield from pbar
@local_leader_only(default=None) @local_leader_only(default=None)
def logger(data): def logger(data):