actually save per-rank sampler states
This commit is contained in:
parent
74df2f5332
commit
39bc019142
|
@ -228,7 +228,7 @@ def train(
|
||||||
print("Failed to set LR rate to:", rate, str(e))
|
print("Failed to set LR rate to:", rate, str(e))
|
||||||
|
|
||||||
if "export" in command:
|
if "export" in command:
|
||||||
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
train_dl.dataset.save_state_dict(cfg.relpath / f"train_dataset.{global_rank()}.pt")
|
||||||
engines.save_checkpoint()
|
engines.save_checkpoint()
|
||||||
last_save_step = engines.global_step
|
last_save_step = engines.global_step
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user