added button to refresh voice list, enabling KV caching for a bonerific speed increase (credit to https://github.com/152334H/tortoise-tts-fast/)

remotes/1710189933836426429/master
mrq 2023-02-05 17:59:13 +07:00
parent 7b767e1442
commit daebc6c21c
8 changed files with 67 additions and 37 deletions

@ -1,6 +1,6 @@
# AI Voice Cloning for Retards and Savants
This [rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows (with an Nvidia GPU), as well as a stepping stone for anons that genuinely want to play around with TorToiSe.
This [rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows (with an Nvidia GPU), as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts).
Similar to my own findings for Stable Diffusion image generation, this rentry may appear a little disheveled as I note my new findings with TorToiSe. Please keep this in mind if the guide seems to shift a bit or sound confusing.
@ -12,6 +12,8 @@ I link those a bit later on as alternatives for Windows+AMD users. You're free t
I'm extremely lazy and can't be assed to install Arch Linux again, much less create shell script equivalents. The commands should be almost 1:1 with what's in the batch file, save for the line to activate the venv.
I leave this as an exercise to the Linux reader.
>\>Ugh... why bother when I can just abuse 11.AI?
I very much encourage (You) to use 11.AI while it's still viable to use. For the layman, it's easier to go through the hoops of coughing up the $5 or abusing the free trial over actually setting up a TorToiSe environment and dealing with its quirks.
@ -26,8 +28,6 @@ For setting up on Linux, the general framework should be the same, but left as a
For Windows users with an AMD GPU, tough luck, as ROCm drivers are not (easily) available for Windows, and requires inane patches with PyTorch. Consider using the [Colab notebook](https://colab.research.google.com/drive/1wVVqUPqwiDBUVeWWOUNglpGhU3hg_cbR?usp=sharing), or the [Hugging Face space](https://huggingface.co/spaces/mdnestor/tortoise), for `tortoise-tts`.
Lots of available RAM seems to be a requirement, as I see Python eating up 8GiB for generations, and if I'm not careful I'll get OOM errors from the software, so be cautious of memory problems if you're doing other things while it runs in the background. For long text generations, you might also exhaust your available VRAM with how the software automatically calculates batch size (for example, a 6GiB of VRAM card using 4GiB for the autoregressive sampling step, but the CLVP matching step requiring more than what's available).
### Pre-Requirements
Python 3.9: https://www.python.org/downloads/release/python-3913/
@ -42,45 +42,49 @@ After installing python, open the Start Menu and search for `Command Prompt`. Ty
Paste `git clone https://git.ecker.tech/mrq/tortoise-tts` to download TorToiSe and additional scripts. Inexperienced users can just download the repo as a ZIP, and extract.
Then move into that folder with `cd tortoise-tts`. Afterwards, enter `setup.bat` to automatically enter all the remaining commands.
Afterwards, run `setup.bat` to automatically set things up.
If you've done everything right, you shouldn't have any errors.
### Updating
To check for updates with the Web UI, simply enter `git pull` in the command prompt, while the TorToiSe workspace is the current working directory.
To check for updates, simply run `update.bat`. It should pull from the repo, as well as fetch for any new dependencies.
### Pitfalls You May Encounter
I'll try and make a list of "common" (or what I feel may be common that I experience) issues with getting TorToiSe set up:
* `failed reading zip archive: failed finding central directory`: You had a file fail to download completely during the model downloading initialization phase. Please open either `%USERPROFILE%\.cache\tortoise\models\` or `%USERPROFILE%\.cache\huggingface\models\`, and delete the offending file.
* `failed reading zip archive: failed finding central directory`: You had a file fail to download completely during the model downloading initialization phase. Please open either `.\models\tortoise\` or `.\models\transformers\`, and delete the offending file.
You can deduce what that file is by reading the stack trace. A few lines above the last like will be a line trying to read a model path.
* `torch.cuda.OutOfMemoryError: CUDA out of memory.`: You most likely have a GPU with low VRAM (~4GiB), and the small optimizations with keeping data on the GPU is enough to OOM. Please open the `start.bat` file and add `--low-vram` to the command (for example: `py app.py --low-vram`) to disable those small optimizations.
## Preparing Voice Samples
Now that the tough part is dealt with, it's time to prepare voice sample clips to use.
Now that the tough part is dealt with, it's time to prepare voice clips to use.
Unlike training embeddings for AI image generations, preparing a "dataset" for voice cloning is very simple. While the repo suggests using short clips of about ten seconds each, you aren't required to manually snip them up. I'm not sure which way is "better", as some voices work perfectly fine with two clips with minutes each worth of audio, while other voices work better with ten short clips.
As a general rule of thumb, try to source clips that aren't noisy, and are entirely just the subject you are trying to clone. If you must, run your source sample through a background music/noise remover (how to is an exercise left to the reader). It isn't entirely a detriment if you're unable to provide clean audio, however. Just be wary that you might have some headaches with getting acceptable output.
As a general rule of thumb, try to source clips that aren't noisy, and are entirely just the subject you are trying to clone. If you must, run your source through a background music/noise remover (how to is an exercise left to the reader). It isn't entirely a detriment if you're unable to provide clean audio, however. Just be wary that you might have some headaches with getting acceptable output.
After sourcing your clips, you have two options:
* use all of your samples for voice cloning, providing as much coverage for whatever you may want
* isolate the best of your samples into a few clips (around ten clips each of about ten seconds each), focusing on samples that best match what you're looking to get out of it
After sourcing some clips, here are some considerations whether you should narrow down the pool you used, or not:
* if you're aiming for a specific delivery (for example, having a line re-read but with word(s) replaced), use just that clip with the line. If you want to err on the side of caution, you can add one more similar clip for safety.
* if your source clips are all delivered in a similar manner (for example, the Patrick Bateman example provided later), it's not necessary to cull.
* if you're hoping to generate something non-specific, you're free to just use your entire pool.
Either methods work, but some workloads tend to favor one over the other. If you're running out of options on improving overall cloning quality, consider switching to the other method. In my opinion, the first one seems to work better overall, and rely on other means of improving the quality of cloning.
There's no hard specifics on how many, or how long, your sources should be.
If you're looking to trim your clips, in my opinion, ~~Audacity~~ Tenacity works good enough, as you can easily output your clips into the proper format (22050 Hz sampling rate, 32-bit float encoding), but some of the time, the software will print out some warning message (`WavFileWarning: Chunk (non-data) not understood, skipping it.`), it's safe to assume you need to properly remux it with `ffmpeg`, simply with `ffmpeg -i [input] -ar 22050 -c:a pcm_f32le [output].wav`. Power users can use the previous command instead of relying on Tenacity to remux.
After sourcing your clips, there are some considerations on how to narrow down your voice clips, if needed:
* if you're aiming for a specific delivery (for example, having a line re-read but with word(s) replaced), use just that clip with the line. If you want to err on the side of caution, you can add one more similar clip for safety.
* if you're aiming to generate a wide range of lines, you shouldn't have to worry about culling for similar clips, and you can just dump them all in for use
To me, there's no noticeable difference between combining them into one file, or keeping them all separated (outside of the initial load for a ton of files).
After preparing your clips as WAV files at a sample rate of 22050 Hz, open up the `tortoise-tts` folder you're working in, navigate to `./tortoise/voice/`, create a new folder in whatever name you want, then dump your clips into that folder. While you're in the `voice` folder, you can take a look at the other provided voices.
If you're looking to trim your clips, in my opinion, ~~Audacity~~ Tenacity works good enough, as you can easily output your clips into the proper format (22050 Hz sampling rate), but some of the time, the software will print out some (sometimes harmless, sometimes harmful) warning message (`WavFileWarning: Chunk (non-data) not understood, skipping it.`), it's safe to assume you need to properly remux it with `ffmpeg`, simply with `ffmpeg -i [input] -ar 22050 -c:a pcm_f32le [output].wav`. Power users can use the previous command instead of relying on Tenacity to remux.
**!**NOTE**!**: having a ton of files, regardless of size, substantially increases the time it takes to initialize the voice. I've had it take a while to load 227 or so samples of SA2 Shadow this way. Consider combining them all in one file through Tenacity, with dropping all of your audio files, then Select > Tracks > All, then Tracks > Align Tracks > Align End to End, then exporting the WAV. This does not introduce padding, however.
After preparing your clips as WAV files at a sample rate of 22050 Hz, open up the `tortoise-tts` folder you're working in, navigate to `./tortoise/voice/`, create a new folder in whatever name you want, then dump your clips into that folder. While you're in the `voice` folder, you can take a look at the other provided voices.
## Using the Software
Now you're ready to generate clips. With the command prompt still open, simply enter `start.bat`, and wait for it to print out a URL to open in your browser, something like `http://127.0.0.1:7861`.
Now you're ready to generate clips. With the command prompt still open, simply enter `start.bat`, and wait for it to print out a URL to open in your browser, something like `http://127.0.0.1:7860`.
If you're looking to access your copy of TorToiSe from outside your local network, pass `--share` into the command (for example, `python app.py --share`). You'll get a temporary gradio link to use.
@ -88,7 +92,7 @@ You'll be presented with a bunch of options, but do not be overwhelmed, as most
* `Prompt`: text you want to be read. You wrap text in `[brackets]` for "prompt engineering", where it'll affect the output, but those words won't actually be read.
* `Line Delimiter`: String to split the prompt into pieces. The stitched clip will be stored as `combined.wav`
- Setting this to `\n` will generate each line as one clip before stitching it. Leave blank to disable this.
* `Emotion`: the "emotion" used for the delivery. This is a shortcut to utilizing "prompt engineering" by starting with `[I am really <emotion>,]` in your prompt. This is not a guarantee, however.
* `Emotion`: the "emotion" used for the delivery. This is a shortcut to utilizing "prompt engineering" by starting with `[I am really <emotion>,]` in your prompt. This is merely a suggestion, not a guarantee.
* `Custom Emotion + Prompt`: a non-preset "emotion" used for the delivery. This is a shortcut to utilizing "prompt engineering" by starting with `[<emotion>]` in your prompt.
* `Voice`: the voice you want to clone. You can select `microphone` if you want to use input from your microphone.
* `Microphone Source`: Use your own voice from a line-in source.
@ -100,6 +104,8 @@ You'll be presented with a bunch of options, but do not be overwhelmed, as most
* `Temperature`: how much randomness to introduce to the generated samples. Lower values = better resemblance to the source samples, but some temperature is still required for great output. This value is very inconsistent and entirely depends on the input voice. In other words, some voices will be receptive to playing with this value, while others won't make much of a difference.
* `Pause Size`: Governs how large pauses are at the end of a clip (in token size, not seconds). Increase this if your output gets cut off at the end.
* `Diffusion Sampler`: sampler method during the diffusion pass. Currently, only `P` and `DDIM` are added, but does not seem to offer any substantial differences in my short tests.
`P` refers to the default, vanilla sampling method in `diffusion.py`.
To reiterate, this ***only*** is useful for the diffusion decoding path, after the autoregressive outputs are generated.
After you fill everything out, click `Run`, and wait for your output in the output window. The sampled voice is also returned, but if you're using multiple files, it'll return the first file, rather than a combined file.

@ -22,6 +22,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, preset, seed, c
mic = load_audio(mic_audio, 22050)
voice_samples, conditioning_latents = [mic], None
else:
progress(0, desc="Loading voice...")
voice_samples, conditioning_latents = load_voice(voice)
if voice_samples is not None:
@ -151,6 +152,9 @@ def update_presets(value):
else:
return (gr.update(), gr.update())
def update_voices():
return gr.Dropdown.update(choices=os.listdir(os.path.join("tortoise", "voices")) + ["microphone"])
def main():
with gr.Blocks() as demo:
with gr.Row():
@ -176,6 +180,11 @@ def main():
source="microphone",
type="filepath",
)
refresh_voices = gr.Button(value="Refresh Voice List")
refresh_voices.click(update_voices,
inputs=None,
outputs=voice
)
prompt.change(fn=lambda value: gr.update(value="Custom"),
inputs=prompt,

@ -0,0 +1,4 @@
@echo off
rm .\in\.gitkeep
rm .\out\.gitkeep
for %%a in (".\in\*.*") do ffmpeg -i "%%a" -ar 22050 -ac 1 -c:a pcm_f32le ".\out\%%~na.wav"

@ -169,13 +169,8 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_la
noise = torch.randn(output_shape, device=latents.device) * temperature
mel = None
if sampler == "P":
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
verbose=verbose, progress=progress, desc=desc)
elif sampler == "DDIM":
mel = diffuser.ddim_sample_loop(diffusion_model, output_shape, noise=noise,
diffuser.sampler = sampler.lower()
mel = diffuser.ddim_sample_loop(diffusion_model, output_shape, noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
verbose=verbose, progress=progress, desc=desc)
@ -251,6 +246,7 @@ class TextToSpeech:
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval()
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
self.autoregressive.post_init_gpt2_config(kv_cache=minor_optimizations)
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,

@ -33,13 +33,15 @@ class ResBlock(nn.Module):
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings
self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache
# Model parallel
self.model_parallel = False
self.device_map = None
@ -75,6 +77,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
if not self.kv_cache: past = None
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
@ -341,6 +344,19 @@ class UnifiedVoice(nn.Module):
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, kv_cache=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
@ -461,17 +477,8 @@ class UnifiedVoice(nn.Module):
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
if not hasattr(self, 'inference_model'):
# TODO: Decouple gpt_config from this inference model.
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
self.post_init_gpt2_config(kv_cache=self.kv_cachepost_init_gpt2_config)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
@ -508,4 +515,4 @@ if __name__ == '__main__':
torch.tensor([32, 120]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([250*256,195*256]))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))

@ -537,6 +537,14 @@ class GaussianDiffusion:
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def sample_loop(self, *args, **kwargs):
s = self.sampler.lower()
if s == 'p':
return self.p_sample_loop(*args, **kwargs)
if s == 'ddim':
return self.ddim_sample_loop(*args, **kwargs)
else: raise RuntimeError("sampler not implemented")
def p_sample_loop(
self,
model,