diff --git a/main.py b/main.py index cefa8e3..315f6e6 100755 --- a/main.py +++ b/main.py @@ -1,13 +1,10 @@ import webui as mrq -if 'XDG_CACHE_HOME' not in os.environ: - os.environ['XDG_CACHE_HOME'] = os.path.realpath('./models/') - if 'TORTOISE_MODELS_DIR' not in os.environ: - os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath('./models/tortoise/') + os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/')) if 'TRANSFORMERS_CACHE' not in os.environ: - os.environ['TRANSFORMERS_CACHE'] = os.path.realpath('./models/transformers/') + os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/')) if __name__ == "__main__": mrq.args = mrq.setup_args() diff --git a/tortoise/api.py b/tortoise/api.py index cc006a9..3806cb4 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -29,10 +29,8 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment from tortoise.utils.device import get_device, get_device_name, get_device_batch_size pbar = None - STOP_SIGNAL = False - -MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', os.path.realpath('./models/tortoise/')) +MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))) MODELS = { 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth', 'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',