diff --git a/README.md b/README.md index 3f59b29..69f153e 100755 --- a/README.md +++ b/README.md @@ -130,11 +130,7 @@ Some additional flags you can pass are: ## To-Do * properly pass in `modules` names to `weight_quantization` and `activation_quantization`. - -* fix `quit` hanging when using distributed training. - * train and release a model. - * extend to multiple languages (VALL-E X) and extend to SpeechX features. ## Notice diff --git a/vall_e/config.py b/vall_e/config.py index bfa2546..341e088 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -368,6 +368,8 @@ class Trainer: save_on_quit: bool = True save_frequency: int = 100 + keep_last_checkpoints: int = 0 + load_state_dict: bool = False load_states: bool = True strict_loading: bool = True diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index fdf4fae..edc288f 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -218,7 +218,16 @@ class Engines(dict[str, Engine]): cfg.ckpt_dir.mkdir(parents=True, exist_ok=True) for name, engine in self.items(): - engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag) + save_dir = cfg.ckpt_dir / name + engine.save_checkpoint(save_dir, tag=tag) + if cfg.trainer.keep_last_checkpoints > 0: + checkpoints = list(save_dir.rglob("*/")) + checkpoints.sort(key=lambda x: x.stat().st_mtime) + checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints] + for d in checkpoints: + for p in d.iterdir(): + p.unlink() + d.rmdir() def load_checkpoint(self, tag=None): if not tag: