finally figured out a clean way to handle "resuming" the tqdm bar
This commit is contained in:
parent
35389481ee
commit
3dd31e74d1
|
@ -140,7 +140,7 @@ def load_engines(training=True, **model_kwargs):
|
|||
"proj": "random",
|
||||
"scale_type": "tensor" if is_mini else "channel",
|
||||
"scale": 128 if is_mini else 1,
|
||||
"update_proj_gap": 200,
|
||||
"update_proj_gap": 1,
|
||||
"proj_type": "std",
|
||||
})
|
||||
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
||||
|
|
|
@ -277,7 +277,7 @@ class Apollo(Optimizer):
|
|||
proj: str = "random",
|
||||
scale_type: str = "channel",
|
||||
scale: int = 1,
|
||||
update_proj_gap: int = 200,
|
||||
update_proj_gap: int = 1,
|
||||
proj_type: str = "std",
|
||||
):
|
||||
if lr < 0.0:
|
||||
|
@ -288,6 +288,7 @@ class Apollo(Optimizer):
|
|||
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
||||
|
||||
defaults = {
|
||||
"lr": lr,
|
||||
"betas": betas,
|
||||
|
|
|
@ -102,32 +102,16 @@ def _non_blocking_input():
|
|||
|
||||
|
||||
def _make_infinite_epochs(dl):
    """Yield batches from ``dl`` forever, one progress bar per epoch.

    Each pass wraps the dataloader in a fresh ``tqdm`` bar whose ``total`` is
    the number of batches still remaining, so that a resumed run shows a bar
    sized to the remaining work rather than to the full epoch.

    Args:
        dl: A dataloader whose ``dataset`` exposes ``index()`` and
            ``batches()``.
            NOTE(review): presumably ``index()`` is the count of batches
            already consumed in the current epoch and ``batches()`` the
            per-epoch total -- confirm against the dataset class.

    Yields:
        Whatever iterating ``dl`` yields, repeated epoch after epoch.

    Raises:
        Exception: If no batches remain at the start of a pass
            (``batches() - index() <= 0``).
    """
    while True:
        if dl.dataset.index() == 0:
            _logger.info("New epoch starts.")

        # Size the bar to the remaining batches so a resumed epoch does not
        # need manual ``pbar.update`` / ``pbar.n`` adjustments.
        total = dl.dataset.batches() - dl.dataset.index()
        if total <= 0:
            raise Exception("Empty dataset")

        # Only the global leader renders the bar; other ranks iterate silently.
        with tqdm(
            dl,
            "Epoch progress",
            dynamic_ncols=True,
            disable=not is_global_leader(),
            total=total,
        ) as pbar:
            yield from pbar
|
||||
|
||||
@local_leader_only(default=None)
|
||||
def logger(data):
|
||||
|
|
Loading…
Reference in New Issue
Block a user