diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index 7b33b10..fdc99ed 100644 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -1,24 +1,26 @@ +import os import logging - import requests +import time + from tqdm import tqdm from pathlib import Path -import time - _logger = logging.getLogger(__name__) # to-do: implement automatically downloading model +DEFAULT_MODEL_NAME = os.environ.get("VALLE_DEFAULT_MODEL_NAME", "ar+nar-len-llama-8.sft") DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models' -DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-len-llama-8.sft" +DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME DEFAULT_MODEL_URLS = { 'ar+nar-len-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-len-llama-8/ckpt/fp32.sft', + 'nemo-larger-44khz-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/nemo-larger-44khz-llama-8/fp32.sft', 'wavlm_large_finetune.pth': 'https://huggingface.co/Dongchao/UniAudio/resolve/main/wavlm_large_finetune.pth', } -if not DEFAULT_MODEL_PATH.exists() and Path("./data/models/ar+nar-len-llama-8.sft").exists(): +if not DEFAULT_MODEL_PATH.exists() and Path(f"./data/models/{DEFAULT_MODEL_NAME}").exists(): DEFAULT_MODEL_DIR = Path('./data/models') - DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-len-llama-8.sft" + DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME # kludge, probably better to use HF's model downloader function # to-do: write to a temp file then copy so downloads can be interrupted