diff --git a/vall_e/data.py b/vall_e/data.py
index 5f9ab19..ab1ffdb 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -903,6 +903,7 @@ class Dataset(_Dataset):
 		state_dict = torch_load(path)
 		if "dataset_hash_key" in state_dict:
 			if self.dataset_hash_key != state_dict["dataset_hash_key"]:
+				_logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
 				return
 
 		if self.sampler_type == "path":
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index cd96d21..58d1a6b 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -103,8 +103,10 @@ def _non_blocking_input():
 
 def _make_infinite_epochs(dl):
 	while True:
-		#_logger.info("New epoch starts.")
-		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) # , initial=dl.dataset.index(), total=len(dl.dataset)) # to-do: figure out why this number jumps
+		if dl.dataset.index() == 0:
+			_logger.info("New epoch starts.")
+		# this number may jump from the dataloader sampling before the actual training step happens
+		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), initial=dl.dataset.index(), total=len(dl.dataset))
 
 
 @local_leader_only(default=None)