diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index fd79dcb..b25eeb6 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -228,7 +228,7 @@ def train( print("Failed to set LR rate to:", rate, str(e)) if "export" in command: - train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt") + train_dl.dataset.save_state_dict(cfg.relpath / f"train_dataset.{global_rank()}.pt") engines.save_checkpoint() last_save_step = engines.global_step