just nuked about 9 hours of progress because I didn't make sure it pruned only on the global leader
parent d7152fc7b9
commit b5f247aa11
@@ -456,7 +456,7 @@ def create_datasets():
 
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
-	#train_dataset.sample_type = "speaker"
+	train_dataset.sample_type = cfg.dataset.sample_type #"speaker"
 
 	subtrain_dataset = copy.deepcopy(train_dataset)
 	subtrain_dataset.head_(cfg.evaluation.size)
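Aside on the hunk above: the change stops hard-coding the sampling mode and reads it from config instead. A minimal sketch of the fields these hunks touch, using a stand-in namespace object; the real cfg schema isn't shown in this commit, so every name and default below is an assumption:

from types import SimpleNamespace

# Hypothetical stand-in for the repo's cfg, only to show which fields
# the diff reads; the values are illustrative, not the project's defaults.
cfg = SimpleNamespace(
	dataset=SimpleNamespace(sample_type="speaker"),   # previously hard-coded
	evaluation=SimpleNamespace(size=64),              # rows kept by head_()
	trainer=SimpleNamespace(keep_last_checkpoints=4), # 0 disables pruning
)

print(cfg.dataset.sample_type)  # "speaker"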
@@ -28,7 +28,7 @@ def default_feeder(engine, batch):
 
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-from ..utils.distributed import init_distributed, distributed_initialized
+from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader
 
 import logging
 import time
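The import hunk above pulls in is_global_leader from the project's distributed utilities. Its body isn't part of this commit; a plausible minimal sketch, assuming the usual convention that global rank 0 is the leader (the env-var fallback is a guess, not this repo's code):

import os
import torch.distributed as dist

def is_global_leader() -> bool:
	# Leader = global rank 0 across all nodes. Use the initialized process
	# group when there is one; fall back to the RANK env var otherwise.
	if dist.is_available() and dist.is_initialized():
		return dist.get_rank() == 0
	return int(os.environ.get("RANK", "0")) == 0

Guarding the prune with a check like this matters because every rank would otherwise run the deletion concurrently against the same shared checkpoint directory, which is exactly the race the commit message is apologizing for.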
@@ -220,7 +220,9 @@ class Engines(dict[str, Engine]):
 		for name, engine in self.items():
 			save_dir = cfg.ckpt_dir / name
 			engine.save_checkpoint(save_dir, tag=tag)
-			if cfg.trainer.keep_last_checkpoints > 0:
+			# might be better to prune before saving for safety, but [:0] returns an empty list; could instead do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
+			if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
 				checkpoints = list(save_dir.rglob("*/"))
 				checkpoints.sort(key=lambda x: x.stat().st_mtime)
 				checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
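The slice arithmetic in the inline comment is easy to misread, so here is a quick runnable illustration of the edge cases it mentions (plain Python; the checkpoint names are made up):

# Sorted oldest -> newest, as the st_mtime sort above produces.
checkpoints = ["ckpt_100", "ckpt_200", "ckpt_300", "ckpt_400"]

keep = 2
print(checkpoints[:-keep])  # ['ckpt_100', 'ckpt_200'] -> deleted; newest 2 survive

# The edge case the comment calls out: an end index of 0 yields an empty
# slice, so nothing would ever be selected for deletion.
print(checkpoints[:0])      # []

# An end of None means "no upper bound", i.e. everything is selected,
# which is why the comment falls back to None in the keep == 1 case.
print(checkpoints[:None])   # all four entries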