implicitly load checkpoint if deepspeed checkpoint not found, updated setup script to grab the diskcached dataloader things

This commit is contained in:
mrq 2023-10-06 10:02:45 -05:00
parent 82f02ae9b1
commit 3db7e7dea1
2 changed files with 8 additions and 1 deletions

View File

@ -6,6 +6,8 @@ pip3 install -e .
mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/
wget -P ./training/valle/ckpt/ar+nar-retnet-8/ "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth"
wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/data.tar.gz"
wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/.cache.tar.gz"
wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/raw/main/config.yaml"
tar -xzf ./training/valle/data.tar.gz -C "./training/valle/"
tar -xzf ./training/valle/data.tar.gz -C "./training/valle/" data.h5
tar -xzf ./training/valle/.cache.tar.gz -C "./training/valle/"

View File

@ -76,6 +76,11 @@ def load_engines():
optimizer = None
lr_scheduler = None
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
if not cfg.trainer.load_state_dict and cfg.trainer.backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
print("DeepSpeed checkpoint missing, but weights found.")
cfg.trainer.load_state_dict = True
stats = None
if cfg.trainer.load_state_dict or not model._cfg.training:
load_path = cfg.ckpt_dir / name / "fp32.pth"