actually save per-rank sampler states

This commit is contained in:
mrq 2024-06-01 09:46:32 -05:00
parent 74df2f5332
commit 39bc019142

View File

@@ -228,7 +228,7 @@ def train(
print("Failed to set LR rate to:", rate, str(e))
if "export" in command:
-train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
+train_dl.dataset.save_state_dict(cfg.relpath / f"train_dataset.{global_rank()}.pt")
engines.save_checkpoint()
last_save_step = engines.global_step