From da31baad21d9bfd0ab47733114b82c3076f405ec Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 21 Apr 2022 16:06:43 -0600 Subject: [PATCH] update requirements and some docs --- api.py | 38 +++++++++++++------------------------- read.py | 4 +++- requirements.txt | 3 ++- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/api.py b/api.py index 80aae16..8486d3f 100644 --- a/api.py +++ b/api.py @@ -21,7 +21,12 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance pbar = None + + def download_models(): + """ + Call to download all the models that Tortoise uses. + """ MODELS = { 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/autoregressive.pth', 'clvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/clvp.pth', @@ -51,6 +56,9 @@ def download_models(): def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ if t.shape[-1] == length: return t elif t.shape[-1] < length: @@ -68,7 +76,10 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi conditioning_free=cond_free, conditioning_free_k=cond_free_k) -def load_conditioning(clip, cond_length=132300): +def format_conditioning(clip, cond_length=132300): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. + """ gap = clip.shape[-1] - cond_length if gap < 0: clip = F.pad(clip, pad=(0, abs(gap))) @@ -79,29 +90,6 @@ def load_conditioning(clip, cond_length=132300): return mel_clip.unsqueeze(0).cuda() -def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token, - tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs): - """ - Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates - every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat. - This is a hybrid between beam search and sampling. - """ - token_goal = tokens_per_clip_inference - finished = False - while not finished and token_goal < autoregressive_model.max_mel_tokens: - samples = [] - for b in tqdm(range(num_batches)): - codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs) - samples.append(codes) - for batch in samples: - for i in range(batch.shape[0]): - batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False) - clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False)) - clip_results = torch.cat(clip_results, dim=0) - samples = torch.cat(samples, dim=0) - best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices] - - def fix_autoregressive_output(codes, stop_token, complain=True): """ This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was @@ -222,7 +210,7 @@ class TextToSpeech: if not isinstance(voice_samples, list): voice_samples = [voice_samples] for vs in voice_samples: - conds.append(load_conditioning(vs)) + conds.append(format_conditioning(vs)) conds = torch.stack(conds, dim=1) diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) diff --git a/read.py b/read.py index 9e0bc0c..04cd7a0 100644 --- a/read.py +++ b/read.py @@ -5,10 +5,11 @@ import torch import torch.nn.functional as F import torchaudio -from api import TextToSpeech, load_conditioning +from api import TextToSpeech, format_conditioning from utils.audio import load_audio, get_voices from utils.tokenizer import VoiceBpeTokenizer + def split_and_recombine_text(texts, desired_length=200, max_len=300): # TODO: also split across '!' and '?'. Attempt to keep quotations together. texts = [s.strip() + "." for s in texts.split('.')] @@ -26,6 +27,7 @@ def split_and_recombine_text(texts, desired_length=200, max_len=300): texts.pop(i+1) return texts + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt") diff --git a/requirements.txt b/requirements.txt index 880c033..b971e61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ tokenizers inflect progressbar einops -unidecode \ No newline at end of file +unidecode +entmax \ No newline at end of file