actually save per-rank sampler states
This commit is contained in:
parent 74df2f5332
commit 39bc019142
|
@@ -228,7 +228,7 @@ def train(
|
|||
print("Failed to set LR rate to:", rate, str(e))
|
||||
|
||||
if "export" in command:
|
||||
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
||||
train_dl.dataset.save_state_dict(cfg.relpath / f"train_dataset.{global_rank()}.pt")
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user