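Added pruning of old checkpoints: when `cfg.trainer.keep_last_checkpoints` is greater than 0, only that many most recent checkpoints are kept after each save (the default of 0 disables pruning)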
This commit is contained in:
parent
44c08d828e
commit
d7152fc7b9
|
@ -130,11 +130,7 @@ Some additional flags you can pass are:
|
||||||
## To-Do
|
## To-Do
|
||||||
|
|
||||||
* properly pass in `modules` names to `weight_quantization` and `activation_quantization`.
|
* properly pass in `modules` names to `weight_quantization` and `activation_quantization`.
|
||||||
|
|
||||||
* fix `quit` hanging when using distributed training.
|
|
||||||
|
|
||||||
* train and release a model.
|
* train and release a model.
|
||||||
|
|
||||||
* extend to multiple languages (VALL-E X) and extend to SpeechX features.
|
* extend to multiple languages (VALL-E X) and extend to SpeechX features.
|
||||||
|
|
||||||
## Notice
|
## Notice
|
||||||
|
|
|
@ -368,6 +368,8 @@ class Trainer:
|
||||||
save_on_quit: bool = True
|
save_on_quit: bool = True
|
||||||
save_frequency: int = 100
|
save_frequency: int = 100
|
||||||
|
|
||||||
|
keep_last_checkpoints: int = 0
|
||||||
|
|
||||||
load_state_dict: bool = False
|
load_state_dict: bool = False
|
||||||
load_states: bool = True
|
load_states: bool = True
|
||||||
strict_loading: bool = True
|
strict_loading: bool = True
|
||||||
|
|
|
@ -218,7 +218,16 @@ class Engines(dict[str, Engine]):
|
||||||
|
|
||||||
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
for name, engine in self.items():
|
for name, engine in self.items():
|
||||||
engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
|
save_dir = cfg.ckpt_dir / name
|
||||||
|
engine.save_checkpoint(save_dir, tag=tag)
|
||||||
|
if cfg.trainer.keep_last_checkpoints > 0:
|
||||||
|
checkpoints = list(save_dir.rglob("*/"))
|
||||||
|
checkpoints.sort(key=lambda x: x.stat().st_mtime)
|
||||||
|
checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
|
||||||
|
for d in checkpoints:
|
||||||
|
for p in d.iterdir():
|
||||||
|
p.unlink()
|
||||||
|
d.rmdir()
|
||||||
|
|
||||||
def load_checkpoint(self, tag=None):
|
def load_checkpoint(self, tag=None):
|
||||||
if not tag:
|
if not tag:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user