diff --git a/vall_e/data.py b/vall_e/data.py index 8122e65..5a58f2f 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -748,7 +748,7 @@ def create_datasets(): train_dataset = Dataset( training=True ) val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False ) - train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt" + train_state_path = cfg.relpath / f"sampler.rank{global_rank()}.pt" if train_state_path.exists(): train_dataset.load_state_dict( train_state_path ) diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 5197ce0..971bfe0 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -230,7 +230,7 @@ def train( print("Failed to set LR rate to:", rate, str(e)) if "export" in command: - train_dl.dataset.save_state_dict(cfg.relpath / f"train_dataset.{global_rank()}.pt") + train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt") engines.save_checkpoint() last_save_step = engines.global_step @@ -253,7 +253,7 @@ def train( if engines.global_step != last_save_step: if engines.global_step % save_ckpt_every == 0 or command in saving_commands: - train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt") + train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt") engines.save_checkpoint() last_save_step = engines.global_step