finally figured out a clean way to handle "resuming" the tqdm bar
This commit is contained in:
parent
35389481ee
commit
3dd31e74d1
|
@ -140,7 +140,7 @@ def load_engines(training=True, **model_kwargs):
|
||||||
"proj": "random",
|
"proj": "random",
|
||||||
"scale_type": "tensor" if is_mini else "channel",
|
"scale_type": "tensor" if is_mini else "channel",
|
||||||
"scale": 128 if is_mini else 1,
|
"scale": 128 if is_mini else 1,
|
||||||
"update_proj_gap": 200,
|
"update_proj_gap": 1,
|
||||||
"proj_type": "std",
|
"proj_type": "std",
|
||||||
})
|
})
|
||||||
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
||||||
|
|
|
@ -277,7 +277,7 @@ class Apollo(Optimizer):
|
||||||
proj: str = "random",
|
proj: str = "random",
|
||||||
scale_type: str = "channel",
|
scale_type: str = "channel",
|
||||||
scale: int = 1,
|
scale: int = 1,
|
||||||
update_proj_gap: int = 200,
|
update_proj_gap: int = 1,
|
||||||
proj_type: str = "std",
|
proj_type: str = "std",
|
||||||
):
|
):
|
||||||
if lr < 0.0:
|
if lr < 0.0:
|
||||||
|
@ -288,6 +288,7 @@ class Apollo(Optimizer):
|
||||||
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
||||||
if not 0.0 <= eps:
|
if not 0.0 <= eps:
|
||||||
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
||||||
|
|
||||||
defaults = {
|
defaults = {
|
||||||
"lr": lr,
|
"lr": lr,
|
||||||
"betas": betas,
|
"betas": betas,
|
||||||
|
|
|
@ -102,32 +102,16 @@ def _non_blocking_input():
|
||||||
|
|
||||||
|
|
||||||
def _make_infinite_epochs(dl):
|
def _make_infinite_epochs(dl):
|
||||||
start = dl.dataset.index()
|
|
||||||
total = dl.dataset.batches()
|
|
||||||
manual_update = False
|
|
||||||
|
|
||||||
if total == 0:
|
|
||||||
raise Exception("Empty dataset")
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if dl.dataset.index() == 0:
|
if dl.dataset.index() == 0:
|
||||||
_logger.info("New epoch starts.")
|
_logger.info("New epoch starts.")
|
||||||
|
|
||||||
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
|
total = dl.dataset.batches() - dl.dataset.index()
|
||||||
if start:
|
if total <= 0:
|
||||||
pbar.update(start)
|
raise Exception("Empty dataset")
|
||||||
start = 0
|
|
||||||
"""
|
|
||||||
if start:
|
|
||||||
pbar.n = start
|
|
||||||
start = 0
|
|
||||||
manual_update = True
|
|
||||||
# for some reason this is required
|
|
||||||
if manual_update:
|
|
||||||
pbar.n += 1
|
|
||||||
"""
|
|
||||||
yield from pbar
|
|
||||||
|
|
||||||
|
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), total=total) as pbar:
|
||||||
|
yield from pbar
|
||||||
|
|
||||||
@local_leader_only(default=None)
|
@local_leader_only(default=None)
|
||||||
def logger(data):
|
def logger(data):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user