forked from mrq/tortoise-tts
Update documentation, add optional verbosity
This commit is contained in:
parent
b2ffe02c2d
commit
a3daadd121
197
README.md
197
README.md
|
@ -1,77 +1,182 @@
|
|||
# Tortoise-TTS
|
||||
# TorToiSe
|
||||
|
||||
Tortoise TTS is an experimental text-to-speech program that uses recent machine learning techniques to generate
|
||||
high-quality speech samples.
|
||||
Tortoise is a text-to-speech program built with the following priorities:
|
||||
|
||||
1. Strong multi-voice capabilities.
|
||||
2. Highly realistic prosody and intonation.
|
||||
|
||||
This repo contains all the code needed to run Tortoise TTS in inference mode.
|
||||
|
||||
## What's in a name?
|
||||
|
||||
I'm naming my speech-related repos after Mojave desert flora and fauna. Tortoise is a bit tongue in cheek: this model
|
||||
is insanely slow. It leverages both an autoregressive speech alignment model and a diffusion model, both of which
|
||||
are known for their slow inference. It also performs CLIP sampling, which slows things down even further. You can
|
||||
expect ~5 seconds of speech to take ~30 seconds to produce on the latest hardware. Still, the results are pretty cool.
|
||||
is insanely slow. It leverages both an autoregressive decoder **and** a diffusion decoder; both known for their low
|
||||
sampling rates. On a K80, expect to generate a medium sized sentence every 2 minutes.
|
||||
|
||||
## What the heck is this?
|
||||
## Demos
|
||||
|
||||
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together.
|
||||
These models are all derived from different repositories which are all linked. All the models have been modified
|
||||
for this use case (some substantially so).
|
||||
See [this page](http://nonint.com/static/tortoise_v2_examples.html) for a large list of example outputs.
|
||||
|
||||
First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very
|
||||
similar to the GPT model used by DALLE, except it operates on speech data.
|
||||
Based on: [GPT2 from Transformers](https://huggingface.co/docs/transformers/model_doc/gpt2)
|
||||
## Usage guide
|
||||
|
||||
Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack
|
||||
ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP
|
||||
decoding creates considerably better results.
|
||||
Based on [CLIP from lucidrains](https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py)
|
||||
### Colab
|
||||
|
||||
Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE.
|
||||
Based on [VQVAE2 by rosinality](https://github.com/rosinality/vq-vae-2-pytorch)
|
||||
Colab is the easiest way to try this out. I've put together a notebook you can use here:
|
||||
https://colab.research.google.com/drive/1wVVqUPqwiDBUVeWWOUNglpGhU3hg_cbR?usp=sharing
|
||||
|
||||
Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in
|
||||
a wav file.
|
||||
Based on [ImprovedDiffusion by openai](https://github.com/openai/improved-diffusion)
|
||||
### Installation
|
||||
|
||||
## How do I use this?
|
||||
If you want to use this on your own computer, you must have an NVIDIA GPU. Installation:
|
||||
|
||||
Check out the colab: https://colab.research.google.com/drive/1wVVqUPqwiDBUVeWWOUNglpGhU3hg_cbR?usp=sharing
|
||||
|
||||
Or on a computer with a GPU (with >=16GB of VRAM):
|
||||
```shell
|
||||
git clone https://github.com/neonbjb/tortoise-tts.git
|
||||
cd tortoise-tts
|
||||
pip install -r requirements.txt
|
||||
python do_tts.py
|
||||
```
|
||||
|
||||
## Hand-picked TTS samples
|
||||
### do_tts.py
|
||||
|
||||
I generated ~250 samples from 23 text prompts and 8 voices. The text prompts have never been seen by the model. The
|
||||
voices were pulled from the training set.
|
||||
This script allows you to speak a single phrase with one or more voices.
|
||||
```shell
|
||||
python do_tts.py --text "I'm going to speak this" --voice dotrice --preset fast
|
||||
```
|
||||
|
||||
All of the samples can be found in the results/ folder of this repo. I handpicked a few to show what the model is capable of:
|
||||
### read.py
|
||||
|
||||
- [Atkins - Road not taken](results/favorites/atkins_road_not_taken.wav)
|
||||
- [Dotrice - Rolling Stone interview](results/favorites/dotrice_rollingstone.wav)
|
||||
- [Dotrice - 'Ornaments' from tacotron test set](results/favorites/dotrice_tacotron_samp1.wav)
|
||||
- [Kennard - 'Acute emotional intelligence' from tacotron test set](results/favorites/kennard_tacotron_samp2.wav)
|
||||
- [Mol - Because I could not stop for death](results/favorites/mol_dickenson.wav)
|
||||
- [Mol - Obama](results/favorites/mol_obama.wav)
|
||||
This script provides tools for reading large amounts of text.
|
||||
```shell
|
||||
python read.py --textfile <your text to be read> --voice dotrice
|
||||
```
|
||||
|
||||
Prosody is remarkably good for poetry, despite the fact that it was never trained on poetry.
|
||||
### API
|
||||
|
||||
## How do I train this?
|
||||
Tortoise can be used programmatically, like so:
|
||||
|
||||
Frankly - you don't. Building this model has been a labor of love for me, consuming most of my 6 RTX3090s worth of
|
||||
resources for the better part of 6 months. It uses a dataset I've gathered, refined and transcribed that consists of
|
||||
a lot of audio data which I cannot distribute because of copywrite or no open licenses.
|
||||
```python
|
||||
reference_clips = [utils.audio.load_audio(p, 22050) for p in clips_paths]
|
||||
tts = api.TextToSpeech()
|
||||
pcm_audio = tts.tts_with_preset("your text here", reference_clips, preset='fast')
|
||||
```
|
||||
|
||||
With that said, I'm willing to help you out if you really want to give it a shot. DM me.
|
||||
## Voice customization guide
|
||||
|
||||
Tortoise was specifically trained to be a multi-speaker model. It accomplishes this by consulting reference clips.
|
||||
|
||||
These reference clips are recordings of a speaker that you provide to guide speech generation. These clips are used to determine many properties of the output, such as the pitch and tone of the voice, speaking speed, and even speaking defects like a lisp or stuttering. The reference clip is also used to determine non-voice related aspects of the audio output like volume, background noise, recording quality and reverb.
|
||||
|
||||
### Provided voices
|
||||
|
||||
This repo comes with several pre-packaged voices. You will be familiar with many of them. :)
|
||||
|
||||
Most of the provided voices were not found in the training set. Experimentally, it seems that voices from the training set
|
||||
produce more realistic outputs then those outside of the training set. The following voices come from the training set:
|
||||
atkins, dotrice, grace, harris, kennard, lescault, mol, otto.
|
||||
|
||||
### Adding a new voice
|
||||
|
||||
To add new voices to Tortoise, you will need to do the following:
|
||||
|
||||
1. Gather audio clips of your speaker(s). Good sources are YouTube interviews (you can use youtube-dl to fetch the audio), audiobooks or podcasts. Guidelines for good clips are in the next section.
|
||||
2. Cut your clips into ~10 second segments. You want at least 3 clips. More is better, but I only experimented with up to 5 in my testing.
|
||||
3. Save the clips as a WAV file with floating point format and a 22,050 sample rate.
|
||||
4. Create a subdirectory in voices/
|
||||
5. Put your clips in that subdirectory.
|
||||
6. Run tortoise utilities with --voice=<your_subdirectory_name>.
|
||||
|
||||
### Picking good reference clips
|
||||
|
||||
As mentioned above, your reference clips have a profound impact on the output of Tortoise. Following are some tips for picking
|
||||
good clips:
|
||||
|
||||
1. Avoid clips with background music, noise or reverb. These clips were removed from the training dataset. Tortoise is unlikely to do well with them.
|
||||
2. Avoid speeches. These generally have distortion caused by the amplification system.
|
||||
3. Avoid clips from phone calls.
|
||||
4. Avoid clips that have excessive stuttering, stammering or words like "uh" or "like" in them.
|
||||
5. Try to find clips that are spoken in such a way as you wish your output to sound like. For example, if you want to hear your target voice read an audiobook, try to find clips of them reading a book.
|
||||
6. The text being spoken in the clips does not matter, but diverse text does seem to perform better.
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Generation settings
|
||||
|
||||
Tortoise is primarily an autoregressive decoder model combined with a diffusion model. Both of these have a lot of knobs
|
||||
that can be turned that I've abstracted away for the sake of ease of use. I did this by generating thousands of clips using
|
||||
various permutations of the settings and using a metric for voice realism and intelligibility to measure their effects. I've
|
||||
set the defaults to the best overall settings I was able to find. For specific use-cases, it might be effective to play with
|
||||
these settings (and it's very likely that I missed something!)
|
||||
|
||||
These settings are not available in the normal scripts packaged with Tortoise. They are available, however, in the API. See
|
||||
```api.tts``` for a full list.
|
||||
|
||||
### Playing with the voice latent
|
||||
|
||||
Tortoise ingests reference clips by feeding them through individually through a small submodel that produces a point latent, then taking the mean of all of the produced latents. The experimentation I have done has indicated that these point latents are quite expressive, affecting
|
||||
everything from tone to speaking rate to speech abnormalities.
|
||||
|
||||
This lends itself to some neat tricks. For example, you can combine feed two different voices to tortoise and it will output what it thinks the "average" of those two voices sounds like. You could also theoretically build a small extension to Tortoise that gradually shifts the
|
||||
latent from one speaker to another, then apply it across a bit of spoken text (something I havent implemented yet, but might
|
||||
get to soon!) I am sure there are other interesting things that can be done here. Please let me know what you find!
|
||||
|
||||
### Send me feedback!
|
||||
|
||||
Probabilistic models like Tortoise are best thought of as an "augmented search" - in this case, through the space of possible
|
||||
utterances of a specific string of text. The impact of community involvement in perusing these spaces (such as is being done with
|
||||
GPT-3 or CLIP) has really surprised me. If you find something neat that you can do with Tortoise that isn't documented here,
|
||||
please report it to me! I would be glad to publish it to this page.
|
||||
|
||||
## Model architecture
|
||||
|
||||
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data and using a better decoder. It is made up of 5 separate
|
||||
models that work together. I've assembled a write-up of the system architecture here:
|
||||
[https://nonint.com/2022/04/25/tortoise-architectural-design-doc/](https://nonint.com/2022/04/25/tortoise-architectural-design-doc/)
|
||||
|
||||
## Training
|
||||
|
||||
These models were trained on my "homelab" server with 8 RTX 3090s over the course of several months. They were trained on a dataset consisting of
|
||||
~50k hours of speech data, most of which was transcribed by [ocotillo](http://www.github.com/neonbjb/ocotillo). Training was done on my own
|
||||
[DLAS](https://github.com/neonbjb/DL-Art-School) trainer.
|
||||
|
||||
I currently do not have plans to release the training configurations or methodology. See the next section..
|
||||
|
||||
## Ethical Considerations
|
||||
|
||||
Tortoise v2 works considerably better than I had planned. When I began hearing some of the outputs of the last few versions, I began
|
||||
wondering whether or not I had an ethically unsound project on my hands. The ways in which a voice-cloning text-to-speech system
|
||||
could be misused are many. It doesn't take much creativity to think up how.
|
||||
|
||||
After consulting with friends and family, I have decided to go forward with releasing this. Following are the reasons for this choice:
|
||||
|
||||
1. It is primarily good at reading books and speaking poetry. Other forms of speech do not work well.
|
||||
2. It was trained on a dataset which does not have the voices of public figures. While it will attempt to mimic these voices if they are provided as references, it does not do so in such a way that most humans would be fooled.
|
||||
3. The above points could likely be resolved by scaling up the model and the dataset. For this reason, I am currently withholding details on how I trained the model, pending community feedback.
|
||||
4. I am releasing a separate classifier model which will tell you whether a given audio clip was generated by Tortoise or not. See `tortoise-detect` above.
|
||||
5. If I, a tinkerer with a BS in computer science with a ~$15k computer can build this, then any motivated corporation or state can as well. I would prefer that it be in the open and everyone know the kinds of things ML can do.
|
||||
|
||||
### Diversity
|
||||
|
||||
The diversity expressed by ML models is strongly tied to the datasets they were trained on.
|
||||
|
||||
Tortoise was trained primarily on a dataset consisting of audiobooks. I made no effort to
|
||||
balance diversity in this dataset. For this reason, Tortoise will be particularly poor at generating the voices of minorities
|
||||
or of people who speak with strong accents.
|
||||
|
||||
## Looking forward
|
||||
|
||||
I'm not satisfied with this yet. Treat this as a "sneak peek" and check back in a couple of months. I think the concept
|
||||
is sound, but there are a few hurdles to overcome to get sample quality up. I have been doing major tweaks to the
|
||||
diffusion model and should have something new and much better soon.
|
||||
Tortoise v2 is about as good as I think I can do in the TTS world with the resources I have access to. A phenomenon that happens when
|
||||
training very large models is that as parameter count increases, the communication bandwidth needed to support distributed training
|
||||
of the model increases multiplicatively. On enterprise-grade hardware, this is not an issue: GPUs are attached together with
|
||||
exceptionally wide buses that can accommodate this bandwidth. I cannot afford enterprise hardware, though, so I am stuck.
|
||||
|
||||
I want to mention here
|
||||
that I think Tortoise could do be a **lot** better. The three major components of Tortoise are either vanilla Transformer Encoder stacks
|
||||
or Decoder stacks. Both of these types of models have a rich experimental history with scaling in the NLP realm. I see no reason
|
||||
to believe that the same is not true of TTS.
|
||||
|
||||
The largest model in Tortoise v2 is considerably smaller than GPT-2 large. It is 20x smaller that the original DALLE transformer.
|
||||
Imagine what a TTS model trained at or near GPT-3 or DALLE scale could achieve.
|
||||
|
||||
## Notice
|
||||
|
||||
Tortoise was built entirely by me using my own hardware. My employer was not involved in any facet of Tortoise's development.
|
||||
|
||||
If you use this repo or the ideas therein for your research, please cite it! A bibtex entree can be found in the right pane on GitHub.
|
65
api.py
65
api.py
|
@ -119,7 +119,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
|
|||
return codes
|
||||
|
||||
|
||||
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_samples, temperature=1):
|
||||
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_samples, temperature=1, verbose=True):
|
||||
"""
|
||||
Uses the specified diffusion model to convert discrete codes into a spectrogram.
|
||||
"""
|
||||
|
@ -139,7 +139,8 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_sa
|
|||
|
||||
noise = torch.randn(output_shape, device=latents.device) * temperature
|
||||
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
||||
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
||||
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
|
||||
progress=verbose)
|
||||
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
||||
|
||||
|
||||
|
@ -203,14 +204,59 @@ class TextToSpeech:
|
|||
kwargs.update(presets[preset])
|
||||
return self.tts(text, voice_samples, **kwargs)
|
||||
|
||||
def tts(self, text, voice_samples, k=1,
|
||||
def tts(self, text, voice_samples, k=1, verbose=True,
|
||||
# autoregressive generation parameters follow
|
||||
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
|
||||
typical_sampling=False, typical_mass=.9,
|
||||
# CLVP & CVVP parameters
|
||||
clvp_cvvp_slider=.5,
|
||||
# diffusion generation parameters follow
|
||||
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
|
||||
**hf_generate_kwargs):
|
||||
"""
|
||||
Produces an audio clip of the given text being spoken with the given reference voice.
|
||||
:param text: Text to be spoken.
|
||||
:param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data.
|
||||
:param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP and CVVP models) clips are returned.
|
||||
:param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true.
|
||||
~~AUTOREGRESSIVE KNOBS~~
|
||||
:param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP+CVVP.
|
||||
As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
|
||||
:param temperature: The softmax temperature of the autoregressive model.
|
||||
:param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs.
|
||||
:param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence
|
||||
of long silences or "uhhhhhhs", etc.
|
||||
:param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
|
||||
:param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
|
||||
:param typical_sampling: Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666
|
||||
I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
|
||||
could use some tuning.
|
||||
:param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
|
||||
~~CLVP-CVVP KNOBS~~
|
||||
:param clvp_cvvp_slider: Controls the influence of the CLVP and CVVP models in selecting the best output from the autoregressive model.
|
||||
[0,1]. Values closer to 1 will cause Tortoise to emit clips that follow the text more. Values closer to
|
||||
0 will cause Tortoise to emit clips that more closely follow the reference clip (e.g. the voice sounds more
|
||||
similar).
|
||||
~~DIFFUSION KNOBS~~
|
||||
:param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
|
||||
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
|
||||
however.
|
||||
:param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for
|
||||
each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output
|
||||
of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and
|
||||
dramatically improves realism.
|
||||
:param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
|
||||
As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
|
||||
Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k
|
||||
:param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
|
||||
are the "mean" prediction of the diffusion network and will sound bland and smeared.
|
||||
~~OTHER STUFF~~
|
||||
:param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer.
|
||||
Extra keyword args fed to this function get forwarded directly to that API. Documentation
|
||||
here: https://huggingface.co/docs/transformers/internal/generation_utils
|
||||
:return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
||||
text = F.pad(text, (0, 1)) # This may not be necessary.
|
||||
|
||||
|
@ -229,7 +275,9 @@ class TextToSpeech:
|
|||
stop_mel_token = self.autoregressive.stop_mel_token
|
||||
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
|
||||
self.autoregressive = self.autoregressive.cuda()
|
||||
for b in tqdm(range(num_batches)):
|
||||
if verbose:
|
||||
print("Generating autoregressive samples..")
|
||||
for b in tqdm(range(num_batches), disable=not verbose):
|
||||
codes = self.autoregressive.inference_speech(conds, text,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
|
@ -247,7 +295,9 @@ class TextToSpeech:
|
|||
clip_results = []
|
||||
self.clvp = self.clvp.cuda()
|
||||
self.cvvp = self.cvvp.cuda()
|
||||
for batch in samples:
|
||||
if verbose:
|
||||
print("Computing best candidates using CLVP and CVVP")
|
||||
for batch in tqdm(samples, disable=not verbose):
|
||||
for i in range(batch.shape[0]):
|
||||
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
||||
clvp = self.clvp(text.repeat(batch.shape[0], 1), batch, return_loss=False)
|
||||
|
@ -272,7 +322,8 @@ class TextToSpeech:
|
|||
return_latent=True, clip_inputs=False)
|
||||
self.autoregressive = self.autoregressive.cpu()
|
||||
|
||||
print("Performing vocoding..")
|
||||
if verbose:
|
||||
print("Transforming autoregressive outputs into audio..")
|
||||
wav_candidates = []
|
||||
self.diffusion = self.diffusion.cuda()
|
||||
self.vocoder = self.vocoder.cuda()
|
||||
|
@ -291,7 +342,7 @@ class TextToSpeech:
|
|||
latents = latents[:, :k]
|
||||
break
|
||||
|
||||
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature)
|
||||
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature, verbose=verbose)
|
||||
wav = self.vocoder.inference(mel)
|
||||
wav_candidates.append(wav.cpu())
|
||||
self.diffusion = self.diffusion.cpu()
|
||||
|
|
|
@ -17,13 +17,105 @@
|
|||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Welcome to Tortoise! 🐢🐢🐢🐢\n",
|
||||
"\n",
|
||||
"Before you begin, I **strongly** recommend you turn on a GPU runtime.\n",
|
||||
"\n",
|
||||
"There's a reason this is called \"Tortoise\" - this model takes up to a minute to perform inference for a single sentence on a GPU. Expect waits on the order of hours on a CPU."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_pIZ3ZXNp7cf"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "JrK20I32grP6"
|
||||
"id": "JrK20I32grP6",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "44f55dca-5d0a-405e-a4cc-54bc8e16b780"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Cloning into 'tortoise-tts'...\n",
|
||||
"remote: Enumerating objects: 736, done.\u001b[K\n",
|
||||
"remote: Counting objects: 100% (23/23), done.\u001b[K\n",
|
||||
"remote: Compressing objects: 100% (15/15), done.\u001b[K\n",
|
||||
"remote: Total 736 (delta 10), reused 20 (delta 8), pack-reused 713\u001b[K\n",
|
||||
"Receiving objects: 100% (736/736), 348.62 MiB | 24.08 MiB/s, done.\n",
|
||||
"Resolving deltas: 100% (161/161), done.\n",
|
||||
"/content/tortoise-tts\n",
|
||||
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 1)) (1.10.0+cu111)\n",
|
||||
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 2)) (0.10.0+cu111)\n",
|
||||
"Collecting rotary_embedding_torch\n",
|
||||
" Downloading rotary_embedding_torch-0.1.5-py3-none-any.whl (4.1 kB)\n",
|
||||
"Collecting transformers\n",
|
||||
" Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)\n",
|
||||
"\u001b[K |████████████████████████████████| 4.0 MB 5.3 MB/s \n",
|
||||
"\u001b[?25hCollecting tokenizers\n",
|
||||
" Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n",
|
||||
"\u001b[K |████████████████████████████████| 6.6 MB 31.3 MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: inflect in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 6)) (2.1.0)\n",
|
||||
"Collecting progressbar\n",
|
||||
" Downloading progressbar-2.5.tar.gz (10 kB)\n",
|
||||
"Collecting einops\n",
|
||||
" Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n",
|
||||
"Collecting unidecode\n",
|
||||
" Downloading Unidecode-1.3.4-py3-none-any.whl (235 kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 235 kB 44.3 MB/s \n",
|
||||
"\u001b[?25hCollecting entmax\n",
|
||||
" Downloading entmax-1.0.tar.gz (7.2 kB)\n",
|
||||
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->-r requirements.txt (line 1)) (4.1.1)\n",
|
||||
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers->-r requirements.txt (line 4)) (4.64.0)\n",
|
||||
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers->-r requirements.txt (line 4)) (21.3)\n",
|
||||
"Collecting sacremoses\n",
|
||||
" Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 895 kB 36.6 MB/s \n",
|
||||
"\u001b[?25hCollecting huggingface-hub<1.0,>=0.1.0\n",
|
||||
" Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 77 kB 6.3 MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers->-r requirements.txt (line 4)) (3.6.0)\n",
|
||||
"Collecting pyyaml>=5.1\n",
|
||||
" Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 596 kB 38.9 MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers->-r requirements.txt (line 4)) (1.21.6)\n",
|
||||
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers->-r requirements.txt (line 4)) (2.23.0)\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers->-r requirements.txt (line 4)) (2019.12.20)\n",
|
||||
"Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers->-r requirements.txt (line 4)) (4.11.3)\n",
|
||||
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers->-r requirements.txt (line 4)) (3.0.8)\n",
|
||||
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers->-r requirements.txt (line 4)) (3.8.0)\n",
|
||||
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (1.24.3)\n",
|
||||
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (3.0.4)\n",
|
||||
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (2.10)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->-r requirements.txt (line 4)) (2021.10.8)\n",
|
||||
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (1.15.0)\n",
|
||||
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (1.1.0)\n",
|
||||
"Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers->-r requirements.txt (line 4)) (7.1.2)\n",
|
||||
"Building wheels for collected packages: progressbar, entmax\n",
|
||||
" Building wheel for progressbar (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
||||
" Created wheel for progressbar: filename=progressbar-2.5-py3-none-any.whl size=12082 sha256=bb7d90605d0bf4d89aedc46bd8ed39538f55e00ee70fa382c1af81f142f08fa8\n",
|
||||
" Stored in directory: /root/.cache/pip/wheels/f0/fd/1f/3e35ed57e94cd8ced38dd46771f1f0f94f65fec548659ed855\n",
|
||||
" Building wheel for entmax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
||||
" Created wheel for entmax: filename=entmax-1.0-py3-none-any.whl size=11015 sha256=5e2cf723e790ec941984d2030eb3231e1ae3ce75231709391a13edcd2bfb4770\n",
|
||||
" Stored in directory: /root/.cache/pip/wheels/f7/e8/0d/acc29c2f66e69a1f42483347fa8545c293dec12325ee161716\n",
|
||||
"Successfully built progressbar entmax\n",
|
||||
"Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, einops, unidecode, transformers, rotary-embedding-torch, progressbar, entmax\n",
|
||||
" Attempting uninstall: pyyaml\n",
|
||||
" Found existing installation: PyYAML 3.13\n",
|
||||
" Uninstalling PyYAML-3.13:\n",
|
||||
" Successfully uninstalled PyYAML-3.13\n",
|
||||
"Successfully installed einops-0.4.1 entmax-1.0 huggingface-hub-0.5.1 progressbar-2.5 pyyaml-6.0 rotary-embedding-torch-0.1.5 sacremoses-0.0.49 tokenizers-0.12.1 transformers-4.18.0 unidecode-1.3.4\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!git clone https://github.com/neonbjb/tortoise-tts.git\n",
|
||||
"%cd tortoise-tts\n",
|
||||
|
@ -38,58 +130,156 @@
|
|||
"import torchaudio\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"from utils.tokenizer import VoiceBpeTokenizer\n",
|
||||
"from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder\n",
|
||||
"from models.text_voice_clip import VoiceCLIP\n",
|
||||
"from models.dvae import DiscreteVAE\n",
|
||||
"from models.autoregressive import UnifiedVoice\n",
|
||||
"from api import TextToSpeech\n",
|
||||
"from utils.audio import load_audio, get_voices\n",
|
||||
"\n",
|
||||
"# These have some fairly interesting code that is hidden in the colab. Consider checking it out.\n",
|
||||
"from do_tts import download_models, load_discrete_vocoder_diffuser, load_conditioning, fix_autoregressive_output, do_spectrogram_diffusion"
|
||||
"# This will download all the models used by Tortoise from the HF hub.\n",
|
||||
"tts = TextToSpeech()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Gen09NM4hONQ"
|
||||
"id": "Gen09NM4hONQ",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "35c1fb4b-5998-4e75-9ec9-29521b301db6"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Downloading autoregressive.pth from https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/autoregressive.pth...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Done.\n",
|
||||
"Downloading clvp.pth from https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/clvp.pth...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Done.\n",
|
||||
"Downloading cvvp.pth from https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/cvvp.pth...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Done.\n",
|
||||
"Downloading diffusion_decoder.pth from https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/diffusion_decoder.pth...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Done.\n",
|
||||
"Downloading vocoder.pth from https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/vocoder.pth...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Done.\n",
|
||||
"Removing weight norm...\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Download pretrained models and set up pretrained voice bank. Feel free to upload and add your own voices here.\n",
|
||||
"# To do so, upload two WAV files cropped to 5-10 seconds of someone speaking.\n",
|
||||
"download_models()\n",
|
||||
"preselected_cond_voices = {\n",
|
||||
" # Male voices\n",
|
||||
" 'dotrice': ['voices/dotrice/1.wav', 'voices/dotrice/2.wav'],\n",
|
||||
" 'harris': ['voices/harris/1.wav', 'voices/harris/2.wav'],\n",
|
||||
" 'lescault': ['voices/lescault/1.wav', 'voices/lescault/2.wav'],\n",
|
||||
" 'otto': ['voices/otto/1.wav', 'voices/otto/2.wav'],\n",
|
||||
" # Female voices\n",
|
||||
" 'atkins': ['voices/atkins/1.wav', 'voices/atkins/2.wav'],\n",
|
||||
" 'grace': ['voices/grace/1.wav', 'voices/grace/2.wav'],\n",
|
||||
" 'kennard': ['voices/kennard/1.wav', 'voices/kennard/2.wav'],\n",
|
||||
" 'mol': ['voices/mol/1.wav', 'voices/mol/2.wav'],\n",
|
||||
" }"
|
||||
"# List all the voices available. These are just some random clips I've gathered\n",
|
||||
"# from the internet as well as a few voices from the training dataset.\n",
|
||||
"# Feel free to add your own clips to the voices/ folder.\n",
|
||||
"%ls voices"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "SSleVnRAiEE2"
|
||||
"id": "SSleVnRAiEE2",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "e1eb09e2-1b68-4f81-b679-edb97538da39"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"\u001b[0m\u001b[01;34mangelina_jolie\u001b[0m/ \u001b[01;34mhalle_barry\u001b[0m/ \u001b[01;34mlj\u001b[0m/ \u001b[01;34msamuel_jackson\u001b[0m/\n",
|
||||
"\u001b[01;34matkins\u001b[0m/ \u001b[01;34mharris\u001b[0m/ \u001b[01;34mmol\u001b[0m/ \u001b[01;34msigourney_weaver\u001b[0m/\n",
|
||||
"\u001b[01;34mcarlin\u001b[0m/ \u001b[01;34mhenry_cavill\u001b[0m/ \u001b[01;34mmorgan_freeman\u001b[0m/ \u001b[01;34mtom_hanks\u001b[0m/\n",
|
||||
"\u001b[01;34mdaniel_craig\u001b[0m/ \u001b[01;34mjennifer_lawrence\u001b[0m/ \u001b[01;34mmyself\u001b[0m/ \u001b[01;34mwilliam_shatner\u001b[0m/\n",
|
||||
"\u001b[01;34mdotrice\u001b[0m/ \u001b[01;34mjohn_krasinski\u001b[0m/ \u001b[01;34motto\u001b[0m/\n",
|
||||
"\u001b[01;34memma_stone\u001b[0m/ \u001b[01;34mkennard\u001b[0m/ \u001b[01;34mpatrick_stewart\u001b[0m/\n",
|
||||
"\u001b[01;34mgrace\u001b[0m/ \u001b[01;34mlescault\u001b[0m/ \u001b[01;34mrobert_deniro\u001b[0m/\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# This is the text that will be spoken.\n",
|
||||
"text = \"And took the other as just as fair, and having perhaps the better claim, because it was grassy and wanted wear.\"\n",
|
||||
"# This is the voice that will speak it.\n",
|
||||
"voice = 'atkins'\n",
|
||||
"# This is the number of samples we will generate from the DALLE-style model. More will produce better results, but will take longer to produce.\n",
|
||||
"# I don't recommend going less than 128.\n",
|
||||
"num_autoregressive_samples = 128"
|
||||
"text = \"Joining two modalities results in a surprising increase in generalization! What would happen if we combined them all?\"\n",
|
||||
"\n",
|
||||
"# Here's something for the poetically inclined.. (set text=)\n",
|
||||
"\"\"\"\n",
|
||||
"Then took the other, as just as fair,\n",
|
||||
"And having perhaps the better claim,\n",
|
||||
"Because it was grassy and wanted wear;\n",
|
||||
"Though as for that the passing there\n",
|
||||
"Had worn them really about the same,\"\"\"\n",
|
||||
"\n",
|
||||
"# Pick one of the voices from above\n",
|
||||
"voice = 'dotrice'\n",
|
||||
"# Pick a \"preset mode\" to determine quality. Options: {\"ultra_fast\", \"fast\" (default), \"standard\", \"high_quality\"}. See docs in api.py\n",
|
||||
"preset = \"fast\""
|
||||
],
|
||||
"metadata": {
|
||||
"id": "bt_aoxONjfL2"
|
||||
|
@ -100,149 +290,106 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Prepare data.\n",
|
||||
"tokenizer = VoiceBpeTokenizer()\n",
|
||||
"text = torch.IntTensor(tokenizer.encode(text)).unsqueeze(0).cuda()\n",
|
||||
"text = F.pad(text, (0,1)) # This may not be necessary.\n",
|
||||
"cond_paths = preselected_cond_voices[voice]\n",
|
||||
"# Fetch the voice references and forward execute!\n",
|
||||
"voices = get_voices()\n",
|
||||
"cond_paths = voices[voice]\n",
|
||||
"conds = []\n",
|
||||
"for cond_path in cond_paths:\n",
|
||||
" c, cond_wav = load_conditioning(cond_path)\n",
|
||||
" c = load_audio(cond_path, 22050)\n",
|
||||
" conds.append(c)\n",
|
||||
"conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "KEXOKjIvn6NW"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Load the autoregressive model.\n",
|
||||
"autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,\n",
|
||||
" heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()\n",
|
||||
"autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))\n",
|
||||
"stop_mel_token = autoregressive.stop_mel_token"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Z15xFT_uhP8v"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Perform inference with the autoregressive model, generating num_autoregressive_samples\n",
|
||||
"with torch.no_grad():\n",
|
||||
" samples = []\n",
|
||||
" for b in tqdm(range(num_autoregressive_samples // 16)):\n",
|
||||
" codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,\n",
|
||||
" temperature=.9, num_return_sequences=16, length_penalty=1)\n",
|
||||
" padding_needed = 250 - codes.shape[1]\n",
|
||||
" codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)\n",
|
||||
" samples.append(codes)\n",
|
||||
"\n",
|
||||
"# Delete model weights to conserve memory.\n",
|
||||
"del autoregressive"
|
||||
"gen = tts.tts_with_preset(text, conds, preset)\n",
|
||||
"torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xajqWiEik-j0"
|
||||
"id": "KEXOKjIvn6NW",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "7977bfd7-9fbc-41f7-d3ac-25fd4e350049"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"100%|██████████| 6/6 [01:18<00:00, 13.11s/it]\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
|
||||
" warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n",
|
||||
"/content/tortoise-tts/models/autoregressive.py:359: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
|
||||
" mel_lengths = wav_lengths // self.mel_length_compression\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Performing vocoding..\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"100%|██████████| 32/32 [00:16<00:00, 1.94it/s]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Load the CLIP model.\n",
|
||||
"clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,\n",
|
||||
" num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval()\n",
|
||||
"clip.load_state_dict(torch.load('.models/clip.pth'))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "KNgYSyuyliMs"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Use the CLIP model to select the best autoregressive output to match the given text.\n",
|
||||
"clip_results = []\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for batch in samples:\n",
|
||||
" for i in range(batch.shape[0]):\n",
|
||||
" batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)\n",
|
||||
" text = text[:, :120] # Ugly hack to fix the fact that I didn't train CLIP to handle long enough text.\n",
|
||||
" clip_results.append(clip(text.repeat(batch.shape[0], 1),\n",
|
||||
" torch.full((batch.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),\n",
|
||||
" batch, torch.full((batch.shape[0],), fill_value=batch.shape[1]*1024, dtype=torch.long, device='cuda'),\n",
|
||||
" return_loss=False))\n",
|
||||
" clip_results = torch.cat(clip_results, dim=0)\n",
|
||||
" samples = torch.cat(samples, dim=0)\n",
|
||||
" best_results = samples[torch.topk(clip_results, k=1).indices]\n",
|
||||
"# You can add as many conditioning voices as you want together. Combining\n",
|
||||
"# clips from multiple voices takes the mean of the latent space for all\n",
|
||||
"# voices. This creates a novel voice that is a combination of the two inputs.\n",
|
||||
"#\n",
|
||||
"# Lets see what it would sound like if Picard and Kirk had a kid with a penchant for philosophy:\n",
|
||||
"conds = []\n",
|
||||
"for v in ['patrick_stewart', 'william_shatner']:\n",
|
||||
" cond_paths = voices[v]\n",
|
||||
" for cond_path in cond_paths:\n",
|
||||
" c = load_audio(cond_path, 22050)\n",
|
||||
" conds.append(c)\n",
|
||||
"\n",
|
||||
"# Save samples to CPU memory, delete clip to conserve memory.\n",
|
||||
"samples = samples.cpu()\n",
|
||||
"del clip"
|
||||
"gen = tts.tts_with_preset(\"They used to say that if man was meant to fly, he’d have wings. But he did fly. He discovered he had to.\", conds, preset)\n",
|
||||
"torchaudio.save('captain_kirkard.wav', gen.squeeze(0).cpu(), 24000)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "DDXkM0lclp4U"
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "fYTk8KUezUr5",
|
||||
"outputId": "8a07f251-c90f-4e6a-c204-132b737dfff8"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Load the DVAE and diffusion model.\n",
|
||||
"dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,\n",
|
||||
" record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()\n",
|
||||
"dvae.load_state_dict(torch.load('.models/dvae.pth'), strict=False)\n",
|
||||
"diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],\n",
|
||||
" spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,\n",
|
||||
" conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()\n",
|
||||
"diffusion.load_state_dict(torch.load('.models/diffusion.pth'))\n",
|
||||
"diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "97acSnBal8Q2"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Decode the (best) discrete sequence created by the autoregressive model.\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for b in range(best_results.shape[0]):\n",
|
||||
" code = best_results[b].unsqueeze(0)\n",
|
||||
" wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256, mean=True)\n",
|
||||
" torchaudio.save(f'{voice}_{b}.wav', wav.squeeze(0).cpu(), 22050)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "HEDABTrdl_kM"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Listen to your text! (told you that'd take a long time..)\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"Audio(data=wav.squeeze(0).cpu().numpy(), rate=22050)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "EyHmcdqBmSvf"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"100%|██████████| 6/6 [01:45<00:00, 17.62s/it]\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
|
||||
" warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n",
|
||||
"/content/tortoise-tts/models/autoregressive.py:359: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
|
||||
" mel_lengths = wav_lengths // self.mel_length_compression\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Performing vocoding..\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"100%|██████████| 32/32 [00:16<00:00, 2.00it/s]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -605,7 +605,7 @@ class GaussianDiffusion:
|
|||
img = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
for i in tqdm(indices):
|
||||
for i in tqdm(indices, disable=not progress):
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.p_sample(
|
||||
|
@ -774,7 +774,7 @@ class GaussianDiffusion:
|
|||
# Lazy import so that we don't depend on tqdm.
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
indices = tqdm(indices)
|
||||
indices = tqdm(indices, disable=not progress)
|
||||
|
||||
for i in indices:
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
|
|
Loading…
Reference in New Issue
Block a user