From 3dd31e74d19cdea6c918086c42f419d165c45ca0 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 14 Dec 2024 18:44:43 -0600 Subject: [PATCH] finally figured out a clean way to handle "resuming" the tqdm bar --- vall_e/engines/__init__.py | 2 +- vall_e/utils/ext/apollo.py | 3 ++- vall_e/utils/trainer.py | 26 +++++--------------------- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 0c0d99d..089da0b 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -140,7 +140,7 @@ def load_engines(training=True, **model_kwargs): "proj": "random", "scale_type": "tensor" if is_mini else "channel", "scale": 128 if is_mini else 1, - "update_proj_gap": 200, + "update_proj_gap": 1, "proj_type": "std", }) elif cfg.hyperparameters.optimizer.lower() == "adagrad": diff --git a/vall_e/utils/ext/apollo.py b/vall_e/utils/ext/apollo.py index e029bca..9dfe8b5 100644 --- a/vall_e/utils/ext/apollo.py +++ b/vall_e/utils/ext/apollo.py @@ -277,7 +277,7 @@ class Apollo(Optimizer): proj: str = "random", scale_type: str = "channel", scale: int = 1, - update_proj_gap: int = 200, + update_proj_gap: int = 1, proj_type: str = "std", ): if lr < 0.0: @@ -288,6 +288,7 @@ class Apollo(Optimizer): raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + defaults = { "lr": lr, "betas": betas, diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 1d0fa13..9c8be2c 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -102,33 +102,17 @@ def _non_blocking_input(): def _make_infinite_epochs(dl): - start = dl.dataset.index() - total = dl.dataset.batches() - manual_update = False - - if total == 0: - raise Exception("Empty dataset") - while True: if dl.dataset.index() == 0: _logger.info("New epoch starts.") + + total = dl.dataset.batches() - dl.dataset.index() + if total <= 0: + raise Exception("Empty dataset") - with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar: - if start: - pbar.update(start) - start = 0 - """ - if start: - pbar.n = start - start = 0 - manual_update = True - # for some reason this is required - if manual_update: - pbar.n += 1 - """ + with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), total=total) as pbar: yield from pbar - @local_leader_only(default=None) def logger(data): return _logger.info(json.dumps(data, default=str))