forked from mrq/tortoise-tts
Allow setting models path from environment variable
This commit is contained in:
parent
b5fc8f198b
commit
20220893af
|
@ -25,6 +25,7 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
|
||||||
|
|
||||||
pbar = None
|
pbar = None
|
||||||
|
|
||||||
|
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')
|
||||||
|
|
||||||
def download_models(specific_models=None):
|
def download_models(specific_models=None):
|
||||||
"""
|
"""
|
||||||
|
@ -40,7 +41,7 @@ def download_models(specific_models=None):
|
||||||
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
|
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
|
||||||
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
||||||
}
|
}
|
||||||
os.makedirs('.models', exist_ok=True)
|
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||||
def show_progress(block_num, block_size, total_size):
|
def show_progress(block_num, block_size, total_size):
|
||||||
global pbar
|
global pbar
|
||||||
if pbar is None:
|
if pbar is None:
|
||||||
|
@ -56,10 +57,11 @@ def download_models(specific_models=None):
|
||||||
for model_name, url in MODELS.items():
|
for model_name, url in MODELS.items():
|
||||||
if specific_models is not None and model_name not in specific_models:
|
if specific_models is not None and model_name not in specific_models:
|
||||||
continue
|
continue
|
||||||
if os.path.exists(f'.models/{model_name}'):
|
model_path = os.path.join(MODELS_DIR, model_name)
|
||||||
|
if os.path.exists(model_path):
|
||||||
continue
|
continue
|
||||||
print(f'Downloading {model_name} from {url}...')
|
print(f'Downloading {model_name} from {url}...')
|
||||||
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
request.urlretrieve(url, model_path, show_progress)
|
||||||
print('Done.')
|
print('Done.')
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,7 +156,7 @@ def classify_audio_clip(clip):
|
||||||
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
||||||
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
||||||
dropout=0, kernel_size=5, distribute_zero_label=False)
|
dropout=0, kernel_size=5, distribute_zero_label=False)
|
||||||
classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
|
classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu')))
|
||||||
clip = clip.cpu().unsqueeze(0)
|
clip = clip.cpu().unsqueeze(0)
|
||||||
results = F.softmax(classifier(clip), dim=-1)
|
results = F.softmax(classifier(clip), dim=-1)
|
||||||
return results[0][0]
|
return results[0][0]
|
||||||
|
@ -181,7 +183,7 @@ class TextToSpeech:
|
||||||
Main entry point into Tortoise.
|
Main entry point into Tortoise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True):
|
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True):
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||||
|
@ -276,9 +278,9 @@ class TextToSpeech:
|
||||||
# Lazy-load the RLG models.
|
# Lazy-load the RLG models.
|
||||||
if self.rlg_auto is None:
|
if self.rlg_auto is None:
|
||||||
self.rlg_auto = RandomLatentConverter(1024).eval()
|
self.rlg_auto = RandomLatentConverter(1024).eval()
|
||||||
self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu')))
|
self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu')))
|
||||||
self.rlg_diffusion = RandomLatentConverter(2048).eval()
|
self.rlg_diffusion = RandomLatentConverter(2048).eval()
|
||||||
self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu')))
|
self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu')))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,8 @@ import os
|
||||||
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
from api import TextToSpeech
|
from api import TextToSpeech, MODELS_DIR
|
||||||
from tortoise.utils.audio import load_audio, get_voices, load_voice
|
from utils.audio import load_voice
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -17,7 +17,7 @@ if __name__ == '__main__':
|
||||||
default=.5)
|
default=.5)
|
||||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
'should only be specified if you have custom checkpoints.', default='.models')
|
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||||
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
|
|
@ -4,8 +4,8 @@ import os
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
from api import TextToSpeech
|
from api import TextToSpeech, MODELS_DIR
|
||||||
from utils.audio import load_audio, get_voices, load_voices
|
from utils.audio import load_audio, load_voices
|
||||||
from utils.text import split_and_recombine_text
|
from utils.text import split_and_recombine_text
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ if __name__ == '__main__':
|
||||||
help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
|
help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
|
||||||
default=.5)
|
default=.5)
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
'should only be specified if you have custom checkpoints.', default='.models')
|
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
tts = TextToSpeech(models_dir=args.model_dir)
|
tts = TextToSpeech(models_dir=args.model_dir)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user