This commit is contained in:
mrq 2024-06-01 10:46:42 -05:00
parent 827cf632e7
commit 8cf176ab46
2 changed files with 3 additions and 3 deletions

View File

@ -748,7 +748,7 @@ def create_datasets():
train_dataset = Dataset( training=True )
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt"
train_state_path = cfg.relpath / f"sampler.rank{global_rank()}.pt"
if train_state_path.exists():
train_dataset.load_state_dict( train_state_path )

View File

@ -230,7 +230,7 @@ def train(
print("Failed to set LR rate to:", rate, str(e))
if "export" in command:
train_dl.dataset.save_state_dict(cfg.relpath / f"train_dataset.{global_rank()}.pt")
train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt")
engines.save_checkpoint()
last_save_step = engines.global_step
@ -253,7 +253,7 @@ def train(
if engines.global_step != last_save_step:
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt")
engines.save_checkpoint()
last_save_step = engines.global_step