Merge pull request #55 from jnordberg/models-dir

Make models dir configurable
James Betker 2022-05-19 09:51:21 -06:00 committed by GitHub
commit 4641933d74
4 changed files with 27 additions and 23 deletions

tortoise/api.py

@@ -25,6 +25,7 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 pbar = None
+MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')

 def download_models(specific_models=None):
     """
@@ -40,7 +41,7 @@ def download_models(specific_models=None):
         'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
         'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
     }
-    os.makedirs('.models', exist_ok=True)
+    os.makedirs(MODELS_DIR, exist_ok=True)

     def show_progress(block_num, block_size, total_size):
         global pbar
         if pbar is None:
@@ -56,10 +57,11 @@ def download_models(specific_models=None):
     for model_name, url in MODELS.items():
         if specific_models is not None and model_name not in specific_models:
             continue
-        if os.path.exists(f'.models/{model_name}'):
+        model_path = os.path.join(MODELS_DIR, model_name)
+        if os.path.exists(model_path):
             continue
         print(f'Downloading {model_name} from {url}...')
-        request.urlretrieve(url, f'.models/{model_name}', show_progress)
+        request.urlretrieve(url, model_path, show_progress)
     print('Done.')
@@ -154,7 +156,7 @@ def classify_audio_clip(clip):
     classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
                                                     resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
                                                     dropout=0, kernel_size=5, distribute_zero_label=False)
-    classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
+    classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu')))
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
@@ -181,7 +183,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """

-    def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -276,9 +278,9 @@ class TextToSpeech:
         # Lazy-load the RLG models.
         if self.rlg_auto is None:
             self.rlg_auto = RandomLatentConverter(1024).eval()
-            self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu')))
+            self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu')))
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
-            self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu')))
+            self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu')))
         with torch.no_grad():
             return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
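With this change the default models directory is resolved once, at import time, from the TORTOISE_MODELS_DIR environment variable (falling back to .models), and it can also be overridden per instance via the new models_dir default on TextToSpeech. A minimal usage sketch, not part of the diff: the /data/tortoise-models path is a placeholder, and the tortoise.api import path is assumed from the package layout referenced elsewhere in this commit.

import os

# MODELS_DIR is read at import time, so set the variable before importing tortoise.api.
os.environ['TORTOISE_MODELS_DIR'] = '/data/tortoise-models'  # placeholder path

from tortoise.api import TextToSpeech, MODELS_DIR

print(MODELS_DIR)                 # -> /data/tortoise-models
tts = TextToSpeech()              # checkpoints are looked up under MODELS_DIR
# Or override for a single instance without touching the environment:
tts = TextToSpeech(models_dir='/some/other/checkpoints')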

tortoise/do_tts.py

@@ -3,8 +3,8 @@ import os
 import torchaudio

-from api import TextToSpeech
-from tortoise.utils.audio import load_audio, get_voices, load_voice
+from api import TextToSpeech, MODELS_DIR
+from utils.audio import load_voice

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -17,7 +17,7 @@ if __name__ == '__main__':
                         default=.5)
     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
-                                                      'should only be specified if you have custom checkpoints.', default='.models')
+                                                      'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
     parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
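Because --model_dir now defaults to MODELS_DIR rather than the hard-coded '.models', the same redirect works from the command line, either explicitly or through the environment. Illustrative invocations only; the directory is a placeholder, the script is assumed to live at tortoise/do_tts.py, and unspecified arguments keep their defaults.

python tortoise/do_tts.py --model_dir /data/tortoise-models --output_path results/ --candidates 3
TORTOISE_MODELS_DIR=/data/tortoise-models python tortoise/do_tts.py --output_path results/ --candidates 3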

tortoise/read.py

@@ -4,8 +4,8 @@ import os
 import torch
 import torchaudio

-from api import TextToSpeech
-from utils.audio import load_audio, get_voices, load_voices
+from api import TextToSpeech, MODELS_DIR
+from utils.audio import load_audio, load_voices
 from utils.text import split_and_recombine_text
@@ -21,7 +21,7 @@ if __name__ == '__main__':
                         help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
                         default=.5)
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
-                                                      'should only be specified if you have custom checkpoints.', default='.models')
+                                                      'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
     args = parser.parse_args()
     tts = TextToSpeech(models_dir=args.model_dir)
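The read script follows the same pattern: an explicit --model_dir still wins, but when the flag is omitted the argparse default is MODELS_DIR, so the environment variable flows through to TextToSpeech(models_dir=args.model_dir). For example (placeholder path, script assumed to live at tortoise/read.py, other arguments left at their defaults):

TORTOISE_MODELS_DIR=/data/tortoise-models python tortoise/read.py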

tortoise/utils/audio.py

@@ -82,21 +82,23 @@ def dynamic_range_decompression(x, C=1):
     return torch.exp(x) / C

-def get_voices():
-    subs = os.listdir('tortoise/voices')
+def get_voices(extra_voice_dirs=[]):
+    dirs = ['tortoise/voices'] + extra_voice_dirs
     voices = {}
-    for sub in subs:
-        subj = os.path.join('tortoise/voices', sub)
-        if os.path.isdir(subj):
-            voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth'))
+    for d in dirs:
+        subs = os.listdir(d)
+        for sub in subs:
+            subj = os.path.join(d, sub)
+            if os.path.isdir(subj):
+                voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth'))
     return voices

-def load_voice(voice):
+def load_voice(voice, extra_voice_dirs=[]):
     if voice == 'random':
         return None, None

-    voices = get_voices()
+    voices = get_voices(extra_voice_dirs)
     paths = voices[voice]
     if len(paths) == 1 and paths[0].endswith('.pth'):
         return None, torch.load(paths[0])
@@ -108,14 +110,14 @@ def load_voice(voice):
     return conds, None

-def load_voices(voices):
+def load_voices(voices, extra_voice_dirs=[]):
     latents = []
     clips = []
     for voice in voices:
         if voice == 'random':
             print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
             return None, None
-        clip, latent = load_voice(voice)
+        clip, latent = load_voice(voice, extra_voice_dirs)
         if latent is None:
             assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
             clips.extend(clip)
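The new extra_voice_dirs parameter lets the voice helpers search additional directories alongside the bundled tortoise/voices folder, with each subdirectory treated as one voice. A small sketch of how a caller might use it; the my_voices directory and the voice names are placeholders, and the import path matches the module path shown in the imports above.

from tortoise.utils.audio import get_voices, load_voice, load_voices

# Discover voices in the built-in folder plus a custom one; keys are subdirectory names.
voices = get_voices(extra_voice_dirs=['my_voices'])

# Load one custom voice: returns (conditioning clips, None) for raw .wav/.mp3 audio,
# or (None, latents) when the voice directory holds a single .pth latent file.
clips, latents = load_voice('alice', extra_voice_dirs=['my_voices'])

# Combine several raw-audio voices, all resolved against the same extra directory.
clips, latents = load_voices(['alice', 'bob'], extra_voice_dirs=['my_voices'])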