fixed training tqdm being stubborn
commit dcaf38b359
parent 41d7c30ea5
@@ -1397,6 +1397,11 @@ class Dataset(_Dataset):
 	def index(self):
 		return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size
 
+	def batches(self):
+		if isinstance(self.sampler, BatchedOrderedSampler):
+			return len(self.sampler)
+		return len(self.sampler if self.sampler is not None else self) // self.batch_size
+
 	def __len__(self):
 		if self.sampler_type == "group":
 			return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
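The new `batches()` mirrors the existing `index()`: it reports how many batches one epoch yields, deferring to the sampler when the sampler already counts in batches (`BatchedOrderedSampler`) and otherwise dividing the sample count by `batch_size`. A minimal sketch of the two paths, using hypothetical stand-in samplers rather than the repo's real classes:

    # Sketch only: PlainSampler/BatchedOrderedSampler here are toy stand-ins.
    class PlainSampler:              # yields one sample index at a time
        def __init__(self, n): self.n = n
        def __len__(self): return self.n

    class BatchedOrderedSampler:     # yields whole batches, so len() is already in batches
        def __init__(self, n, bs): self.n, self.bs = n, bs
        def __len__(self): return self.n // self.bs

    def batches(sampler, batch_size, dataset_len):
        if isinstance(sampler, BatchedOrderedSampler):
            return len(sampler)                      # already counted in batches
        n = len(sampler) if sampler is not None else dataset_len
        return n // batch_size                       # samples -> batches

    print(batches(PlainSampler(1000), 8, 1000))              # -> 125
    print(batches(BatchedOrderedSampler(1000, 8), 8, 1000))  # -> 125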
@@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
 
 LN_2 = 0.69314718056
 
+try:
+	from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+	AVAILABLE_ATTENTIONS.append("flex")
+except Exception as e:
+	_logger.warning(f"Error while querying for `flexattention` support: {str(e)}")
+
 try:
 	from transformers.utils import is_flash_attn_2_available
 
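The flex-attention block follows the same guarded-probe pattern the file already uses for flash attention: import the optional backend, record it in `AVAILABLE_ATTENTIONS` on success, and log a warning on failure. Catching `Exception` rather than only `ImportError` means a broken torch install degrades to a warning instead of crashing at import time. A self-contained sketch of the pattern (the `attention_backend` variable and `"sdpa"` fallback are illustrative, not from the repo):

    import logging

    _logger = logging.getLogger(__name__)
    AVAILABLE_ATTENTIONS = []

    try:
        # only present on sufficiently recent PyTorch builds
        from torch.nn.attention.flex_attention import flex_attention, create_block_mask
        AVAILABLE_ATTENTIONS.append("flex")
    except Exception as e:
        _logger.warning(f"Error while querying for `flexattention` support: {str(e)}")

    # callers can then branch on the probe result; "sdpa" here is an assumed fallback
    attention_backend = "flex" if "flex" in AVAILABLE_ATTENTIONS else "sdpa"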
@@ -102,12 +102,18 @@ def _non_blocking_input():
 
 
 def _make_infinite_epochs(dl):
+	start = dl.dataset.index()
+	total = dl.dataset.batches()
+
 	while True:
 		if dl.dataset.index() == 0:
 			_logger.info("New epoch starts.")
 		#yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
-		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index())
+		# this number may jump from the dataloader sampling before the actual training step happens
+		with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) as pbar:
+			if start:
+				pbar.n = start
+				start = 0
+			yield from pbar
 
 
 @local_leader_only(default=None)
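Per the new in-line comment, `dl.dataset.index()` can run ahead of the actual training step because the dataloader samples in advance, so passing it as `initial=` on every epoch skewed the bar. The rewrite snapshots the index once before the loop, pushes it into `pbar.n` for the first (partial) epoch only, then lets every later epoch start from zero. Setting `pbar.n` adjusts only the displayed count; skipping the items themselves is assumed to be the sampler's job. A standalone sketch, where `resume_at` stands in for `dl.dataset.index()`:

    from tqdm import tqdm

    def make_infinite_epochs(dl, resume_at=0):
        # `resume_at` is a hypothetical stand-in for dl.dataset.index()
        start = resume_at
        while True:
            with tqdm(dl, "Epoch progress", dynamic_ncols=True) as pbar:
                if start:
                    pbar.n = start   # align the display with the resume point
                    pbar.refresh()   # redraw now; display-only, no items are skipped
                    start = 0        # every later epoch starts from zero
                yield from pbar

    # usage: the first epoch's bar opens at 40/100, later epochs at 0/100
    epochs = make_infinite_epochs(range(100), resume_at=40)
    next(epochs)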