allow defining the default model name through env var, register nemo-larger in the model name list thing

This commit is contained in:
mrq 2025-05-21 16:50:59 -05:00
parent e46d7ef2cb
commit f12746b091

View File

@ -1,24 +1,26 @@
import os
import logging
import requests
import time
from tqdm import tqdm
from pathlib import Path
import time
_logger = logging.getLogger(__name__)
# to-do: implement automatically downloading model
DEFAULT_MODEL_NAME = os.environ.get("VALLE_DEFAULT_MODEL_NAME", "ar+nar-len-llama-8.sft")
DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-len-llama-8.sft"
DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME
DEFAULT_MODEL_URLS = {
'ar+nar-len-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-len-llama-8/ckpt/fp32.sft',
'nemo-larger-44khz-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/nemo-larger-44khz-llama-8/fp32.sft',
'wavlm_large_finetune.pth': 'https://huggingface.co/Dongchao/UniAudio/resolve/main/wavlm_large_finetune.pth',
}
if not DEFAULT_MODEL_PATH.exists() and Path("./data/models/ar+nar-len-llama-8.sft").exists():
if not DEFAULT_MODEL_PATH.exists() and Path(f"./data/models/{DEFAULT_MODEL_NAME}").exists():
DEFAULT_MODEL_DIR = Path('./data/models')
DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-len-llama-8.sft"
DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME
# kludge, probably better to use HF's model downloader function
# to-do: write to a temp file then copy so downloads can be interrupted