update requirements and some docs

This commit is contained in:
James Betker 2022-04-21 16:06:43 -06:00
parent 9f28b005f3
commit 3735a819b3
3 changed files with 18 additions and 27 deletions

38
api.py
View File

@ -21,7 +21,12 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance
pbar = None pbar = None
def download_models(): def download_models():
"""
Call to download all the models that Tortoise uses.
"""
MODELS = { MODELS = {
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/autoregressive.pth', 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/autoregressive.pth',
'clvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/clvp.pth', 'clvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/clvp.pth',
@ -51,6 +56,9 @@ def download_models():
def pad_or_truncate(t, length): def pad_or_truncate(t, length):
"""
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
"""
if t.shape[-1] == length: if t.shape[-1] == length:
return t return t
elif t.shape[-1] < length: elif t.shape[-1] < length:
@ -68,7 +76,10 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
conditioning_free=cond_free, conditioning_free_k=cond_free_k) conditioning_free=cond_free, conditioning_free_k=cond_free_k)
def load_conditioning(clip, cond_length=132300): def format_conditioning(clip, cond_length=132300):
"""
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
"""
gap = clip.shape[-1] - cond_length gap = clip.shape[-1] - cond_length
if gap < 0: if gap < 0:
clip = F.pad(clip, pad=(0, abs(gap))) clip = F.pad(clip, pad=(0, abs(gap)))
@ -79,29 +90,6 @@ def load_conditioning(clip, cond_length=132300):
return mel_clip.unsqueeze(0).cuda() return mel_clip.unsqueeze(0).cuda()
def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token,
tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs):
"""
Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates
every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat.
This is a hybrid between beam search and sampling.
"""
token_goal = tokens_per_clip_inference
finished = False
while not finished and token_goal < autoregressive_model.max_mel_tokens:
samples = []
for b in tqdm(range(num_batches)):
codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs)
samples.append(codes)
for batch in samples:
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False)
clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False))
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
def fix_autoregressive_output(codes, stop_token, complain=True): def fix_autoregressive_output(codes, stop_token, complain=True):
""" """
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
@ -222,7 +210,7 @@ class TextToSpeech:
if not isinstance(voice_samples, list): if not isinstance(voice_samples, list):
voice_samples = [voice_samples] voice_samples = [voice_samples]
for vs in voice_samples: for vs in voice_samples:
conds.append(load_conditioning(vs)) conds.append(format_conditioning(vs))
conds = torch.stack(conds, dim=1) conds = torch.stack(conds, dim=1)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)

View File

@ -5,10 +5,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
from api import TextToSpeech, load_conditioning from api import TextToSpeech, format_conditioning
from utils.audio import load_audio, get_voices from utils.audio import load_audio, get_voices
from utils.tokenizer import VoiceBpeTokenizer from utils.tokenizer import VoiceBpeTokenizer
def split_and_recombine_text(texts, desired_length=200, max_len=300): def split_and_recombine_text(texts, desired_length=200, max_len=300):
# TODO: also split across '!' and '?'. Attempt to keep quotations together. # TODO: also split across '!' and '?'. Attempt to keep quotations together.
texts = [s.strip() + "." for s in texts.split('.')] texts = [s.strip() + "." for s in texts.split('.')]
@ -26,6 +27,7 @@ def split_and_recombine_text(texts, desired_length=200, max_len=300):
texts.pop(i+1) texts.pop(i+1)
return texts return texts
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt") parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")

View File

@ -7,3 +7,4 @@ inflect
progressbar progressbar
einops einops
unidecode unidecode
entmax