fixed training tqdm being stubborn

mrq 2024-11-23 09:45:23 -06:00
parent 41d7c30ea5
commit dcaf38b359
3 changed files with 21 additions and 3 deletions


@@ -1397,6 +1397,11 @@ class Dataset(_Dataset):
 	def index(self):
 		return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size
 
+	def batches(self):
+		if isinstance(self.sampler, BatchedOrderedSampler):
+			return len(self.sampler)
+		return len(self.sampler if self.sampler is not None else self) // self.batch_size
+
 	def __len__(self):
 		if self.sampler_type == "group":
 			return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
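
For context, the new `Dataset.batches()` reports the epoch length in batches regardless of sampler type: `BatchedOrderedSampler` already yields whole batches, so its `len()` is a batch count, while any other sampler (or the dataset itself) is measured in samples and must be divided by `batch_size`. A minimal standalone sketch of that distinction, with hypothetical sampler classes standing in for the repo's:

# hypothetical stand-ins for the repo's samplers; only the __len__ semantics matter
class SampleSampler:
	def __init__(self, num_samples):
		self.num_samples = num_samples
	def __len__(self):
		return self.num_samples  # counts individual samples

class BatchedOrderedSampler:
	def __init__(self, num_samples, batch_size):
		self.num_batches = num_samples // batch_size
	def __len__(self):
		return self.num_batches  # already counts whole batches

def batches(sampler, batch_size):
	# mirrors Dataset.batches(): batch-level samplers are used as-is,
	# sample-level samplers are converted to a batch count
	if isinstance(sampler, BatchedOrderedSampler):
		return len(sampler)
	return len(sampler) // batch_size

assert batches(SampleSampler(1000), 8) == 125
assert batches(BatchedOrderedSampler(1000, 8), 8) == 125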


@@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
 
 LN_2 = 0.69314718056
 
+try:
+	from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+	AVAILABLE_ATTENTIONS.append("flex")
+except Exception as e:
+	_logger.warning(f"Error while querying for `flexattention` support: {str(e)}")
+
 try:
 	from transformers.utils import is_flash_attn_2_available
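
The guard above registers "flex" only when the running torch build actually ships `torch.nn.attention.flex_attention` (present in PyTorch 2.5 and later); on older builds the import raises and the backend is skipped with a warning instead of crashing at import time. A hedged sketch of how such a capability list is typically consumed downstream; the `attend()` dispatcher is illustrative, not the repo's API:

import logging

import torch
import torch.nn.functional as F

_logger = logging.getLogger(__name__)
AVAILABLE_ATTENTIONS = []

try:
	# present in PyTorch >= 2.5; older builds raise ImportError here
	from torch.nn.attention.flex_attention import flex_attention
	AVAILABLE_ATTENTIONS.append("flex")
except Exception as e:
	_logger.warning(f"Error while querying for `flexattention` support: {str(e)}")

def attend(q, k, v):
	# illustrative dispatch: prefer flex_attention when the probe succeeded,
	# otherwise fall back to stock scaled-dot-product attention
	if "flex" in AVAILABLE_ATTENTIONS:
		# default score_mod/block_mask make this equivalent to plain attention
		return flex_attention(q, k, v)
	return F.scaled_dot_product_attention(q, k, v)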


@@ -102,12 +102,18 @@ def _non_blocking_input():
 def _make_infinite_epochs(dl):
+	start = dl.dataset.index()
+	total = dl.dataset.batches()
+
 	while True:
 		if dl.dataset.index() == 0:
 			_logger.info("New epoch starts.")
-		#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
-		# this number may jump from the dataloader sampling before the actual training step happens
-		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index())
+		with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
+			if start:
+				pbar.n = start
+				start = 0
+			yield from pbar
 
 @local_leader_only(default=None)
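
The fix itself: the old code seeded the bar with `initial=dl.dataset.index()` on every epoch, but the dataloader's sampler advances during prefetch, before the training step actually runs, so the displayed count jumped ahead and the resume offset was wrongly re-applied on every subsequent epoch. The new code assigns `pbar.n` once, before iteration begins, then zeroes `start` so later epochs count from 0. A standalone sketch of the pattern; the names are assumed, and `itertools.islice` stands in for a dataloader whose sampler resumes mid-epoch:

import itertools
from tqdm import tqdm

def make_infinite_epochs(make_epoch, total, resume_at=0):
	# yields items forever; only the first epoch after a resume starts mid-bar
	start = resume_at
	while True:
		# a resumed dataloader only yields the remaining batches; islice fakes that here
		epoch = itertools.islice(make_epoch(), start, None)
		with tqdm(epoch, "Epoch progress", total=total, dynamic_ncols=True) as pbar:
			if start:
				pbar.n = start  # fast-forward the display; tqdm reads .n when iteration starts
				start = 0       # every later epoch begins at zero
			yield from pbar

# e.g. resuming at batch 3 of a 10-batch epoch:
steps = make_infinite_epochs(lambda: iter(range(10)), total=10, resume_at=3)
next(steps)  # consumes batch 3; the bar advances from 3/10 toward 10/10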