forked from mrq/tortoise-tts
Merge remote-tracking branch 'origin/main'
# Conflicts: # tortoise/read.py
This commit is contained in:
commit
a1c131bde9
|
@ -25,6 +25,7 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
|
|||
|
||||
pbar = None
|
||||
|
||||
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')
|
||||
|
||||
def download_models(specific_models=None):
|
||||
"""
|
||||
|
@ -39,7 +40,7 @@ def download_models(specific_models=None):
|
|||
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
|
||||
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
||||
}
|
||||
os.makedirs('.models', exist_ok=True)
|
||||
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||
def show_progress(block_num, block_size, total_size):
|
||||
global pbar
|
||||
if pbar is None:
|
||||
|
@ -55,10 +56,11 @@ def download_models(specific_models=None):
|
|||
for model_name, url in MODELS.items():
|
||||
if specific_models is not None and model_name not in specific_models:
|
||||
continue
|
||||
if os.path.exists(f'.models/{model_name}'):
|
||||
model_path = os.path.join(MODELS_DIR, model_name)
|
||||
if os.path.exists(model_path):
|
||||
continue
|
||||
print(f'Downloading {model_name} from {url}...')
|
||||
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
||||
request.urlretrieve(url, model_path, show_progress)
|
||||
print('Done.')
|
||||
|
||||
|
||||
|
@ -153,7 +155,7 @@ def classify_audio_clip(clip):
|
|||
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
||||
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
||||
dropout=0, kernel_size=5, distribute_zero_label=False)
|
||||
classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
|
||||
classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu')))
|
||||
clip = clip.cpu().unsqueeze(0)
|
||||
results = F.softmax(classifier(clip), dim=-1)
|
||||
return results[0][0]
|
||||
|
@ -180,7 +182,7 @@ class TextToSpeech:
|
|||
Main entry point into Tortoise.
|
||||
"""
|
||||
|
||||
def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True):
|
||||
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True):
|
||||
"""
|
||||
Constructor
|
||||
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
||||
|
@ -271,9 +273,9 @@ class TextToSpeech:
|
|||
# Lazy-load the RLG models.
|
||||
if self.rlg_auto is None:
|
||||
self.rlg_auto = RandomLatentConverter(1024).eval()
|
||||
self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu')))
|
||||
self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu')))
|
||||
self.rlg_diffusion = RandomLatentConverter(2048).eval()
|
||||
self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu')))
|
||||
self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu')))
|
||||
with torch.no_grad():
|
||||
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@ import os
|
|||
import torch
|
||||
import torchaudio
|
||||
|
||||
from api import TextToSpeech
|
||||
from tortoise.utils.audio import load_audio, get_voices, load_voice
|
||||
from api import TextToSpeech, MODELS_DIR
|
||||
from utils.audio import load_voice
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -15,7 +15,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
|
||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||
'should only be specified if you have custom checkpoints.', default='.models')
|
||||
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
||||
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
||||
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||
|
|
|
@ -5,8 +5,8 @@ from time import time
|
|||
import torch
|
||||
import torchaudio
|
||||
|
||||
from api import TextToSpeech
|
||||
from utils.audio import load_audio, get_voices, load_voices
|
||||
from api import TextToSpeech, MODELS_DIR
|
||||
from utils.audio import load_audio, load_voices
|
||||
from utils.text import split_and_recombine_text
|
||||
|
||||
|
||||
|
@ -19,10 +19,14 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
|
||||
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
|
||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||
<<<<<<< HEAD
|
||||
'should only be specified if you have custom checkpoints.', default='.models')
|
||||
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
||||
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||
|
||||
=======
|
||||
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||
>>>>>>> origin/main
|
||||
args = parser.parse_args()
|
||||
tts = TextToSpeech(models_dir=args.model_dir)
|
||||
|
||||
|
|
|
@ -82,21 +82,23 @@ def dynamic_range_decompression(x, C=1):
|
|||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def get_voices():
|
||||
subs = os.listdir('tortoise/voices')
|
||||
def get_voices(extra_voice_dirs=[]):
|
||||
dirs = ['tortoise/voices'] + extra_voice_dirs
|
||||
voices = {}
|
||||
for sub in subs:
|
||||
subj = os.path.join('tortoise/voices', sub)
|
||||
if os.path.isdir(subj):
|
||||
voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth'))
|
||||
for d in dirs:
|
||||
subs = os.listdir(d)
|
||||
for sub in subs:
|
||||
subj = os.path.join(d, sub)
|
||||
if os.path.isdir(subj):
|
||||
voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth'))
|
||||
return voices
|
||||
|
||||
|
||||
def load_voice(voice):
|
||||
def load_voice(voice, extra_voice_dirs=[]):
|
||||
if voice == 'random':
|
||||
return None, None
|
||||
|
||||
voices = get_voices()
|
||||
voices = get_voices(extra_voice_dirs)
|
||||
paths = voices[voice]
|
||||
if len(paths) == 1 and paths[0].endswith('.pth'):
|
||||
return None, torch.load(paths[0])
|
||||
|
@ -108,14 +110,14 @@ def load_voice(voice):
|
|||
return conds, None
|
||||
|
||||
|
||||
def load_voices(voices):
|
||||
def load_voices(voices, extra_voice_dirs=[]):
|
||||
latents = []
|
||||
clips = []
|
||||
for voice in voices:
|
||||
if voice == 'random':
|
||||
print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
|
||||
return None, None
|
||||
clip, latent = load_voice(voice)
|
||||
clip, latent = load_voice(voice, extra_voice_dirs)
|
||||
if latent is None:
|
||||
assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
|
||||
clips.extend(clip)
|
||||
|
|
Loading…
Reference in New Issue
Block a user