fixed training tqdm being stubborn

This commit is contained in:
mrq 2024-11-23 09:45:23 -06:00
parent 41d7c30ea5
commit dcaf38b359
3 changed files with 21 additions and 3 deletions

View File

@ -1396,6 +1396,11 @@ class Dataset(_Dataset):
def index(self):
    """Current position expressed in batches.

    Takes the sampler's raw sample index (or the -1 sentinel when no
    sampler is set) and floor-divides it by the batch size.
    """
    if self.sampler is not None:
        raw = self.sampler.index()
    else:
        raw = -1
    return raw // self.batch_size
def batches(self):
    """Number of batches in one epoch under the active sampler.

    A BatchedOrderedSampler already yields whole batches, so its length
    is used directly; otherwise the sample count (from the sampler, or
    from the dataset itself when no sampler is set) is floor-divided by
    the batch size.
    """
    sampler = self.sampler
    if isinstance(sampler, BatchedOrderedSampler):
        return len(sampler)
    source = sampler if sampler is not None else self
    return len(source) // self.batch_size
def __len__(self):
if self.sampler_type == "group":

View File

@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
LN_2 = 0.69314718056
# Probe for torch's FlexAttention support; on success, register "flex" as a
# usable attention backend in the module-level AVAILABLE_ATTENTIONS list.
try:
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
    AVAILABLE_ATTENTIONS.append("flex")
except Exception as e:
    # Best-effort probe: any failure (module missing, older torch build) is
    # only logged, so the rest of the backends can still be queried.
    _logger.warning(f"Error while querying for `flexattention` support: {str(e)}")
try:
from transformers.utils import is_flash_attn_2_available

View File

@ -102,12 +102,18 @@ def _non_blocking_input():
def _make_infinite_epochs(dl):
    """Yield batches from dataloader `dl` forever, looping over epochs.

    A tqdm progress bar wraps each epoch, shown only on the global leader.
    When resuming mid-epoch, the bar's counter is seeded from the dataset's
    current batch index so progress does not appear to restart from zero.
    """
    start = dl.dataset.index()
    # NOTE(review): `total` is currently unused — presumably intended as the
    # bar's total batch count; confirm before wiring it into tqdm.
    total = dl.dataset.batches()
    while True:
        if dl.dataset.index() == 0:
            _logger.info("New epoch starts.")
        # The dataloader samples ahead of the actual training step, so the
        # displayed count may jump; seed the counter via `pbar.n` (once, on
        # the first epoch after resume) rather than rebuilding the bar with
        # `initial=` every loop iteration.
        with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
            if start:
                pbar.n = start
                start = 0
            yield from pbar
@local_leader_only(default=None)