ugh
This commit is contained in:
parent
827cf632e7
commit
8cf176ab46
|
@ -748,7 +748,7 @@ def create_datasets():
|
|||
train_dataset = Dataset( training=True )
|
||||
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
||||
|
||||
train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt"
|
||||
train_state_path = cfg.relpath / f"sampler.rank{global_rank()}.pt"
|
||||
if train_state_path.exists():
|
||||
train_dataset.load_state_dict( train_state_path )
|
||||
|
||||
|
|
|
@ -230,7 +230,7 @@ def train(
|
|||
print("Failed to set LR rate to:", rate, str(e))
|
||||
|
||||
if "export" in command:
|
||||
train_dl.dataset.save_state_dict(cfg.relpath / f"train_dataset.{global_rank()}.pt")
|
||||
train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt")
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
|
@ -253,7 +253,7 @@ def train(
|
|||
|
||||
if engines.global_step != last_save_step:
|
||||
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
||||
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
||||
train_dl.dataset.save_state_dict(cfg.relpath / f"sampler.rank{global_rank()}.pt")
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user