diff --git a/tortoise/api.py b/tortoise/api.py index fa915b4..5707b99 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -25,6 +25,7 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment pbar = None +MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models') def download_models(specific_models=None): """ @@ -40,7 +41,7 @@ def download_models(specific_models=None): 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth', 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', } - os.makedirs('.models', exist_ok=True) + os.makedirs(MODELS_DIR, exist_ok=True) def show_progress(block_num, block_size, total_size): global pbar if pbar is None: @@ -56,10 +57,11 @@ def download_models(specific_models=None): for model_name, url in MODELS.items(): if specific_models is not None and model_name not in specific_models: continue - if os.path.exists(f'.models/{model_name}'): + model_path = os.path.join(MODELS_DIR, model_name) + if os.path.exists(model_path): continue print(f'Downloading {model_name} from {url}...') - request.urlretrieve(url, f'.models/{model_name}', show_progress) + request.urlretrieve(url, model_path, show_progress) print('Done.') @@ -154,7 +156,7 @@ def classify_audio_clip(clip): classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, dropout=0, kernel_size=5, distribute_zero_label=False) - classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu'))) + classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu'))) clip = clip.cpu().unsqueeze(0) results = F.softmax(classifier(clip), dim=-1) return results[0][0] @@ -181,7 +183,7 @@ class TextToSpeech: Main entry point into Tortoise. """ - def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True): + def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True): """ Constructor :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing @@ -276,9 +278,9 @@ class TextToSpeech: # Lazy-load the RLG models. if self.rlg_auto is None: self.rlg_auto = RandomLatentConverter(1024).eval() - self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu'))) + self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu'))) self.rlg_diffusion = RandomLatentConverter(2048).eval() - self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu'))) + self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu'))) with torch.no_grad(): return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py index b74466c..091781f 100644 --- a/tortoise/do_tts.py +++ b/tortoise/do_tts.py @@ -3,8 +3,8 @@ import os import torchaudio -from api import TextToSpeech -from tortoise.utils.audio import load_audio, get_voices, load_voice +from api import TextToSpeech, MODELS_DIR +from utils.audio import load_voice if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -17,7 +17,7 @@ if __name__ == '__main__': default=.5) parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/') parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' - 'should only be specified if you have custom checkpoints.', default='.models') + 'should only be specified if you have custom checkpoints.', default=MODELS_DIR) parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3) args = parser.parse_args() os.makedirs(args.output_path, exist_ok=True) diff --git a/tortoise/read.py b/tortoise/read.py index e81bd71..ac284cc 100644 --- a/tortoise/read.py +++ b/tortoise/read.py @@ -4,8 +4,8 @@ import os import torch import torchaudio -from api import TextToSpeech -from utils.audio import load_audio, get_voices, load_voices +from api import TextToSpeech, MODELS_DIR +from utils.audio import load_audio, load_voices from utils.text import split_and_recombine_text @@ -21,7 +21,7 @@ if __name__ == '__main__': help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility', default=.5) parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' - 'should only be specified if you have custom checkpoints.', default='.models') + 'should only be specified if you have custom checkpoints.', default=MODELS_DIR) args = parser.parse_args() tts = TextToSpeech(models_dir=args.model_dir) diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py index fda6380..6cdd496 100644 --- a/tortoise/utils/audio.py +++ b/tortoise/utils/audio.py @@ -82,21 +82,23 @@ def dynamic_range_decompression(x, C=1): return torch.exp(x) / C -def get_voices(): - subs = os.listdir('tortoise/voices') +def get_voices(extra_voice_dirs=[]): + dirs = ['tortoise/voices'] + extra_voice_dirs voices = {} - for sub in subs: - subj = os.path.join('tortoise/voices', sub) - if os.path.isdir(subj): - voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth')) + for d in dirs: + subs = os.listdir(d) + for sub in subs: + subj = os.path.join(d, sub) + if os.path.isdir(subj): + voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth')) return voices -def load_voice(voice): +def load_voice(voice, extra_voice_dirs=[]): if voice == 'random': return None, None - voices = get_voices() + voices = get_voices(extra_voice_dirs) paths = voices[voice] if len(paths) == 1 and paths[0].endswith('.pth'): return None, torch.load(paths[0]) @@ -108,14 +110,14 @@ def load_voice(voice): return conds, None -def load_voices(voices): +def load_voices(voices, extra_voice_dirs=[]): latents = [] clips = [] for voice in voices: if voice == 'random': print("Cannot combine a random voice with a non-random voice. Just using a random voice.") return None, None - clip, latent = load_voice(voice) + clip, latent = load_voice(voice, extra_voice_dirs) if latent is None: assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this." clips.extend(clip)