fixed training tqdm being stubborn
This commit is contained in:
parent
41d7c30ea5
commit
dcaf38b359
|
@ -1397,6 +1397,11 @@ class Dataset(_Dataset):
|
|||
def index(self):
|
||||
return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size
|
||||
|
||||
def batches(self):
|
||||
if isinstance(self.sampler, BatchedOrderedSampler):
|
||||
return len(self.sampler)
|
||||
return len(self.sampler if self.sampler is not None else self) // self.batch_size
|
||||
|
||||
def __len__(self):
|
||||
if self.sampler_type == "group":
|
||||
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
|
||||
|
|
|
@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
|
|||
|
||||
LN_2 = 0.69314718056
|
||||
|
||||
try:
|
||||
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
|
||||
AVAILABLE_ATTENTIONS.append("flex")
|
||||
except Exception as e:
|
||||
_logger.warning(f"Error while querying for `flexattention` support: {str(e)}")
|
||||
|
||||
try:
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
|
|
|
@ -102,12 +102,18 @@ def _non_blocking_input():
|
|||
|
||||
|
||||
def _make_infinite_epochs(dl):
|
||||
start = dl.dataset.index()
|
||||
total = dl.dataset.batches()
|
||||
|
||||
while True:
|
||||
if dl.dataset.index() == 0:
|
||||
_logger.info("New epoch starts.")
|
||||
#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
|
||||
# this number may jump from the dataloader sampling before the actual training step happens
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index())
|
||||
|
||||
with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
|
||||
if start:
|
||||
pbar.n = start
|
||||
start = 0
|
||||
yield from pbar
|
||||
|
||||
|
||||
@local_leader_only(default=None)
|
||||
|
|
Loading…
Reference in New Issue
Block a user