resume iteration step in tqdm trainer, warn to logger if the sampler state dict was invalidated
parent 8286aa54c8
commit 976ee87f6f
@@ -903,6 +903,7 @@ class Dataset(_Dataset):
 		state_dict = torch_load(path)
 		if "dataset_hash_key" in state_dict:
 			if self.dataset_hash_key != state_dict["dataset_hash_key"]:
+				_logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
 				return
 
 		if self.sampler_type == "path":
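Note: the hunk above only adds the warning line; the sketch below is a minimal, standalone illustration of the same guard, not the project's actual Dataset class. The names compute_hash_key and SketchDataset are hypothetical, and a plain dict stands in for the torch checkpoint.

import hashlib
import json
import logging

_logger = logging.getLogger(__name__)

def compute_hash_key(paths):
    # Hypothetical stand-in for however the real dataset derives its hash key,
    # e.g. hashing the sorted list of sample paths.
    return hashlib.sha256(json.dumps(sorted(paths)).encode()).hexdigest()

class SketchDataset:
    def __init__(self, paths, dataset_type="training"):
        self.dataset_type = dataset_type
        self.dataset_hash_key = compute_hash_key(paths)
        self.sampler_index = 0

    def state_dict(self):
        return {
            "dataset_hash_key": self.dataset_hash_key,
            "sampler_index": self.sampler_index,
        }

    def load_state_dict(self, state_dict):
        # Same shape as the diff: refuse to restore sampler state that was
        # produced against a different dataset composition, and say so.
        if "dataset_hash_key" in state_dict:
            if self.dataset_hash_key != state_dict["dataset_hash_key"]:
                _logger.warning(
                    f"Mismatched dataset hash key for {self.dataset_type} dataloader, "
                    "ignoring loading of state dict."
                )
                return
        self.sampler_index = state_dict.get("sampler_index", 0)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    old = SketchDataset(["a.wav", "b.wav"])
    old.sampler_index = 123
    saved = old.state_dict()

    resumed = SketchDataset(["a.wav", "b.wav", "c.wav"])  # dataset changed on disk
    resumed.load_state_dict(saved)  # warns, keeps sampler_index at 0
    print(resumed.sampler_index)    # -> 0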
@@ -103,8 +103,10 @@ def _non_blocking_input():
 
 def _make_infinite_epochs(dl):
 	while True:
-		#_logger.info("New epoch starts.")
-		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) # , initial=dl.dataset.index(), total=len(dl.dataset)) # to-do: figure out why this number jumps
+		if dl.dataset.index() == 0:
+			_logger.info("New epoch starts.")
+		# this number may jump from the dataloader sampling before the actual training step happens
+		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
 
 
 @local_leader_only(default=None)
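Note: a rough standalone sketch of the resume behaviour in the hunk above. It assumes a hypothetical ResumableDataset with an index() method; the real trainer reads the position off dl.dataset of a dataloader and gates the bar with is_global_leader(). The point is that tqdm is started with initial= set to where the sampler left off, so a restarted run does not show the bar from zero.

import logging
from tqdm import tqdm

_logger = logging.getLogger(__name__)

class ResumableDataset:
    """Iterable that remembers how far into the current epoch it is."""
    def __init__(self, items, start_index=0):
        self.items = items
        self._index = start_index

    def index(self):
        return self._index

    def __len__(self):
        return len(self.items)

    def __iter__(self):
        # Resume mid-epoch, then wrap back to zero for the next epoch.
        while self._index < len(self.items):
            item = self.items[self._index]
            self._index += 1
            yield item
        self._index = 0

def make_infinite_epochs(dl):
    while True:
        if dl.index() == 0:
            _logger.info("New epoch starts.")
        # `initial` may run slightly ahead of the actual training step, since
        # the dataloader samples before the step is executed.
        yield from tqdm(dl, "Epoch progress", dynamic_ncols=True,
                        initial=dl.index(), total=len(dl))

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    ds = ResumableDataset(list(range(10)), start_index=4)  # pretend we resumed mid-epoch
    stream = make_infinite_epochs(ds)
    for _, batch in zip(range(12), stream):
        pass  # training step would go here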