From b5f247aa11b3c3685021204be336de997e258fdb Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 16 Aug 2023 23:37:52 -0500
Subject: [PATCH] just nuked about 9 hours of progress because I didn't make
 sure it pruned only on the global leader

---
 vall_e/data.py         | 2 +-
 vall_e/engines/base.py | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index f3f371b..c06daf6 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -456,7 +456,7 @@ def create_datasets():
 
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
-	#train_dataset.sample_type = "speaker"
+	train_dataset.sample_type = cfg.dataset.sample_type #"speaker"
 
 	subtrain_dataset = copy.deepcopy(train_dataset)
 	subtrain_dataset.head_(cfg.evaluation.size)
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index edc288f..abff0d1 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -28,7 +28,7 @@ def default_feeder(engine, batch):
 
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-from ..utils.distributed import init_distributed, distributed_initialized
+from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader
 
 import logging
 import time
@@ -220,7 +220,9 @@ class Engines(dict[str, Engine]):
 		for name, engine in self.items():
 			save_dir = cfg.ckpt_dir / name
 			engine.save_checkpoint(save_dir, tag=tag)
-			if cfg.trainer.keep_last_checkpoints > 0:
+
+			# might be better to prune before saving for safety, but [:0] returns an empty list; I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None] instead
+			if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
 				checkpoints = list(save_dir.rglob("*/"))
 				checkpoints.sort(key=lambda x: x.stat().st_mtime)
 				checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
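
Note on the fix: under distributed training every rank runs Engines.save_checkpoint(), so before this patch every rank also ran the pruning loop and could delete checkpoint directories while other ranks were still writing them; gating the pruning on is_global_leader() makes exactly one process responsible for cleanup. The sketch below is a minimal, hedged illustration of that logic, not the repo's code: `keep` stands in for cfg.trainer.keep_last_checkpoints, `prune_before_save` is a hypothetical flag modeling the "prune before saving" variant the in-diff comment considers, and the torch.distributed-based leader check is only an assumed implementation of is_global_leader().

    from pathlib import Path
    import shutil

    import torch.distributed as dist


    def is_global_leader() -> bool:
        # Assumed implementation: global rank 0 across all nodes, with a
        # non-distributed run counting as its own leader.
        return not dist.is_initialized() or dist.get_rank() == 0


    def prune_checkpoints(save_dir: Path, keep: int, prune_before_save: bool = False) -> None:
        # Only one process may delete checkpoints; letting every rank prune
        # concurrently is what destroyed the run named in the subject line.
        if keep <= 0 or not is_global_leader():
            return

        # Sort oldest-first so the slices below drop the oldest entries.
        checkpoints = sorted(save_dir.rglob("*/"), key=lambda p: p.stat().st_mtime)

        if prune_before_save:
            # A new checkpoint is about to be written, so keep only keep - 1
            # existing ones. Naively slicing [:-(keep - 1)] breaks at keep == 1,
            # because [:0] is the empty slice and nothing would be deleted;
            # a stop index of None deletes every existing checkpoint instead.
            stale = checkpoints[: -(keep - 1) if keep > 1 else None]
        else:
            # Pruning after the save, as the patched code does: drop everything
            # but the newest keep checkpoints.
            stale = checkpoints[:-keep]

        for path in stale:
            shutil.rmtree(path)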