From dcaf38b3597b3df8277b1a8a402740df22ce2744 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 23 Nov 2024 09:45:23 -0600
Subject: [PATCH] fixed training tqdm being stubborn

---
 vall_e/data.py              |  5 +++++
 vall_e/models/arch/llama.py |  7 +++++++
 vall_e/utils/trainer.py     | 12 +++++++++---
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index acb7849..fb7d020 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1396,6 +1396,11 @@ class Dataset(_Dataset):
 
 	def index(self):
 		return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size
+
+	def batches(self):
+		if isinstance(self.sampler, BatchedOrderedSampler):
+			return len(self.sampler)
+		return len(self.sampler if self.sampler is not None else self) // self.batch_size
 
 	def __len__(self):
 		if self.sampler_type == "group":
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 03093a7..208c1de 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -22,6 +22,13 @@
 AVAILABLE_ATTENTIONS = []
 
 LN_2 = 0.69314718056
 
+try:
+	from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+	AVAILABLE_ATTENTIONS.append("flex")
+except Exception as e:
+	_logger.warning(f"Error while querying for `flexattention` support: {str(e)}")
+
 try:
 	from transformers.utils import is_flash_attn_2_available
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index bf9969c..2e708d6 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -102,12 +102,18 @@ def _non_blocking_input():
 
 
 def _make_infinite_epochs(dl):
+	start = dl.dataset.index()
+	total = dl.dataset.batches()
+
 	while True:
 		if dl.dataset.index() == 0:
 			_logger.info("New epoch starts.")
-		#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
-		# this number may jump from the dataloader sampling before the actual training step happens
-		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index())
+
+		with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
+			if start:
+				pbar.n = start
+				start = 0
+			yield from pbar
 
 
 @local_leader_only(default=None)