resume the iteration step in the tqdm trainer, warn via the logger if the sampler state dict was invalidated

mrq 2024-11-13 09:09:28 -06:00
parent 8286aa54c8
commit 976ee87f6f
2 changed files with 5 additions and 2 deletions

@@ -903,6 +903,7 @@ class Dataset(_Dataset):
 		state_dict = torch_load(path)
 		if "dataset_hash_key" in state_dict:
 			if self.dataset_hash_key != state_dict["dataset_hash_key"]:
+				_logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
 				return
 		if self.sampler_type == "path":
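
For context, the guard in the hunk above follows a common invalidation pattern: the dataset derives a hash key from whatever defines its sampling order, saves it inside the sampler state dict, and on load discards any state saved under a different key. Below is a minimal sketch of that pattern; aside from the dataset_hash_key and sampler_type names taken from the diff, every name here (the paths argument, the index field, the hashing inputs) is an illustrative assumption, not code from this repo.

import hashlib
import json
import logging

_logger = logging.getLogger(__name__)

class Dataset:
	def __init__(self, paths, sampler_type="path"):
		self.paths = sorted(paths)
		self.sampler_type = sampler_type
		self.index = 0
		# any change to the inputs hashed below yields a new key, so state dicts
		# saved under an older dataset configuration are silently invalidated
		payload = json.dumps({"paths": self.paths, "sampler_type": self.sampler_type})
		self.dataset_hash_key = hashlib.sha256(payload.encode("utf-8")).hexdigest()

	def state_dict(self):
		return {"dataset_hash_key": self.dataset_hash_key, "index": self.index}

	def load_state_dict(self, state_dict):
		# same guard as the hunk above: warn and keep the fresh state
		if self.dataset_hash_key != state_dict.get("dataset_hash_key"):
			_logger.warning("Mismatched dataset hash key, ignoring loading of state dict.")
			return
		self.index = state_dict["index"]

Warning rather than raising matches the recoverable nature of the failure: a stale sampler position just means the epoch restarts from index 0.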

@@ -103,8 +103,10 @@ def _non_blocking_input():
 def _make_infinite_epochs(dl):
 	while True:
-		#_logger.info("New epoch starts.")
-		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) # , initial=dl.dataset.index(), total=len(dl.dataset)) # to-do: figure out why this number jumps
+		if dl.dataset.index() == 0:
+			_logger.info("New epoch starts.")
+		# this number may jump from the dataloader sampling before the actual training step happens
+		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
 @local_leader_only(default=None)
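
To see what the initial=/total= change buys, here is a standalone sketch, assuming a dataset whose index() reports how many samples its sampler has already served. The ResumableDataset class and its numbers are illustrative stand-ins for the real dataloader, and the sketch iterates the dataset directly (dropping the disable=not is_global_leader() flag) rather than going through a dl.dataset wrapper.

from tqdm import tqdm

class ResumableDataset:
	def __init__(self, length, start_index=0):
		self.length = length
		self._index = start_index  # e.g. restored from a sampler state dict

	def index(self):
		return self._index

	def __len__(self):
		return self.length

	def __iter__(self):
		while self._index < self.length:
			yield self._index
			self._index += 1

dl = ResumableDataset(10, start_index=3)
# the bar now starts at 3/10 instead of 0/10 after a resume; as the new comment
# in the diff notes, the displayed count can run slightly ahead of the completed
# training steps because the dataloader is consumed before each step executes
for _ in tqdm(dl, "Epoch progress", dynamic_ncols=True, initial=dl.index(), total=len(dl)):
	pass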