added pruning of old checkpoints if specified (cfg.trainer.keep_last_checkpoints)

mrq 2023-08-16 20:12:12 -05:00
parent 44c08d828e
commit d7152fc7b9
3 changed files with 12 additions and 5 deletions

@@ -130,11 +130,7 @@ Some additional flags you can pass are:
 ## To-Do
 * properly pass in `modules` names to `weight_quantization` and `activation_quantization`.
-* fix `quit` hanging when using distributed training.
 * train and release a model.
 * extend to multiple languages (VALL-E X) and extend to SpeechX features.
 ## Notice

@@ -368,6 +368,8 @@ class Trainer:
     save_on_quit: bool = True
     save_frequency: int = 100
+    keep_last_checkpoints: int = 0
     load_state_dict: bool = False
     load_states: bool = True
     strict_loading: bool = True

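To illustrate the new field's semantics (the default of 0 leaves pruning off; a positive value keeps only that many checkpoints per engine), here is a minimal, hypothetical sketch. The field names match the hunk above, but the trimmed-down dataclass and the direct keyword construction are illustrative only; in practice the value comes from the project's configuration file.

```python
from dataclasses import dataclass

# Hypothetical, trimmed-down stand-in for the Trainer config above;
# only the checkpoint-related fields shown in the hunk are reproduced.
@dataclass
class Trainer:
    save_on_quit: bool = True
    save_frequency: int = 100
    keep_last_checkpoints: int = 0  # 0 disables pruning
    load_state_dict: bool = False
    load_states: bool = True
    strict_loading: bool = True

# Keep only the 5 most recent checkpoints for each engine.
trainer = Trainer(keep_last_checkpoints=5)
```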

@@ -218,7 +218,16 @@ class Engines(dict[str, Engine]):
         cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
         for name, engine in self.items():
-            engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
+            save_dir = cfg.ckpt_dir / name
+            engine.save_checkpoint(save_dir, tag=tag)
+            if cfg.trainer.keep_last_checkpoints > 0:
+                checkpoints = list(save_dir.rglob("*/"))
+                checkpoints.sort(key=lambda x: x.stat().st_mtime)
+                checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
+                for d in checkpoints:
+                    for p in d.iterdir():
+                        p.unlink()
+                    d.rmdir()
 
     def load_checkpoint(self, tag=None):
         if not tag:
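For reference, the pruning step added above can be read in isolation as the sketch below. This is not the project's code: it assumes each `engine.save_checkpoint(save_dir, tag=tag)` call leaves one checkpoint subdirectory of `save_dir` per tag, and `prune_checkpoints` is a hypothetical helper name. For clarity it selects immediate subdirectories with `iterdir()` rather than the `rglob("*/")` pattern used in the diff.

```python
from pathlib import Path

def prune_checkpoints(save_dir: Path, keep: int) -> None:
    """Hypothetical helper mirroring the hunk above: delete all but the
    `keep` most recently modified checkpoint directories under save_dir."""
    if keep <= 0:
        return  # pruning disabled, matching the keep_last_checkpoints > 0 guard
    # Sort oldest-first so the slice below drops everything except the newest `keep`.
    checkpoints = sorted(
        (d for d in save_dir.iterdir() if d.is_dir()),
        key=lambda d: d.stat().st_mtime,
    )
    for stale in checkpoints[:-keep]:
        for p in stale.iterdir():
            p.unlink()  # assumes each checkpoint directory contains only files
        stale.rmdir()

# e.g. keep only the 3 newest checkpoints saved under ckpt/<engine name>:
# prune_checkpoints(Path("ckpt") / "ar", keep=3)
```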