diff --git a/scripts/setup.sh b/scripts/setup.sh index a423632..26b6fe0 100755 --- a/scripts/setup.sh +++ b/scripts/setup.sh @@ -6,6 +6,8 @@ pip3 install -e . mkdir -p ./training/valle/ckpt/ar+nar-retnet-8/ wget -P ./training/valle/ckpt/ar+nar-retnet-8/ "https://huggingface.co/ecker/vall-e/resolve/main/ckpt/ar%2Bnar-retnet-8/fp32.pth" wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/data.tar.gz" +wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/.cache.tar.gz" wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/raw/main/config.yaml" -tar -xzf ./training/valle/data.tar.gz -C "./training/valle/" \ No newline at end of file +tar -xzf ./training/valle/data.tar.gz -C "./training/valle/" data.h5 +tar -xzf ./training/valle/.cache.tar.gz -C "./training/valle/" \ No newline at end of file diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 6487b91..830e8a0 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -76,6 +76,11 @@ def load_engines(): optimizer = None lr_scheduler = None + # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present + if not cfg.trainer.load_state_dict and cfg.trainer.backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists(): + print("DeepSpeed checkpoint missing, but weights found.") + cfg.trainer.load_state_dict = True + stats = None if cfg.trainer.load_state_dict or not model._cfg.training: load_path = cfg.ckpt_dir / name / "fp32.pth"