# VALL'E

An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.

## Requirements

Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/) for phonemizing text:
- Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
- Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
  + additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
- In the future, an internal homebrew to replace this *would* be fantastic.

## Install

Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.

I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.

## Pre-Trained Model

> [!NOTE]
> Pre-Trained weights aren't up to par as a pure zero-shot model at the moment, but are fine for finetuning / LoRAs.

My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).

A script to set up a proper environment and download the weights can be invoked with `./scripts/setup.sh`.

## Train

Training is very dependent on:
* the quality of your dataset.
* how much data you have.
* the bandwidth you quantized your audio to.
* the underlying model architecture used.

### Try Me

To quickly test if a configuration works, you can run `python -m vall_e.models.ar_nar --yaml="./data/config.yaml"`; a small trainer will overfit a provided utterance.

### Leverage Your Own Dataset

If you already have a dataset you want to use, for example your own large corpus or data for finetuning, you can use it instead.

0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
    + At the moment only WhisperX is utilized.
    Using other variants like `faster-whisper` is an exercise left to the user at the moment.
    + It's recommended to use a dedicated virtualenv specifically for transcribing, as WhisperX will break a few dependencies.
    + The following command should work:
      ```
      python3 -m venv venv-whisper
      source ./venv-whisper/bin/activate
      pip3 install torch torchvision torchaudio
      pip3 install git+https://github.com/m-bain/whisperX/
      ```
1. Populate your source voices under `./voices/{group name}/{speaker name}/`.
2. Run `python3 ./scripts/transcribe_dataset.py`. This will generate a transcription with timestamps for your dataset.
    + If you're interested in using a different model, edit the script's `model_name` and `batch_size` variables.
3. Run `python3 ./scripts/process_dataset.py`. This will phonemize the transcriptions and quantize the audio.
    + If you're using a Descript-Audio-Codec based model, be sure to set the sample rate and audio backend accordingly.
4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset_list.json`.
    + Refer to `./vall_e/config.py` for additional configuration details.

### Dataset Formats

Two dataset formats are supported:
* the standard way:
  - data is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend.
  - it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data --yaml="./training/config.yaml" --action=metadata`
* using an HDF5 dataset:
  - you can convert from the standard way with the following command: `python3 -m vall_e.data --yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
  - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
  - be sure to also define `use_hdf5` in your config YAML.

### Training

For single GPUs, simply run `python3 -m vall_e.train --yaml="./training/config.yaml"`.

For multiple GPUs, or exotic distributed training:
* with `deepspeed` backends, simply running `deepspeed --module vall_e.train --yaml="./training/config.yaml"` should handle the gory details.
* with `local` backends, simply run `torchrun --nnodes=1 --nproc-per-node={NUMOFGPUS} -m vall_e.train --yaml="./training/config.yaml"`

You can enter `save` to save the state at any time, or `quit` to save and quit training.

The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.

### Finetuning

Finetuning can be done by training the full model, or using a LoRA.

Finetuning the full model is done the same way as training a model, but be sure to have the weights in the correct spot, as if you're loading them for inferencing.

For training a LoRA, add the following block to your `config.yaml`:

```
loras:
- name : "arbitrary name" # whatever you want
  rank: 128 # dimensionality of the LoRA
  alpha: 128 # scaling factor of the LoRA
  training: True
```

And that's it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should be, as the LoRA weights are initialized to appropriately random values. I found that a `rank` and `alpha` of 128 work fine.

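
As a rough illustration of what `rank` and `alpha` control, the sketch below uses the standard LoRA formulation, where each adapted linear layer gets two small matrices (`rank x d_in` and `d_out x rank`) and their product is scaled by `alpha / rank`. This describes the general technique and is an assumption about, not a description of, this repo's LoRA code; the helper names and the layer size of 1024 are hypothetical.

```python
def lora_scale(rank: int, alpha: float) -> float:
    """Scaling factor applied to the low-rank update in standard LoRA."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    return alpha / rank


def lora_extra_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added to one d_out x d_in linear layer.

    Standard LoRA adds A (rank x d_in) and B (d_out x rank).
    """
    return rank * (d_in + d_out)


# Hypothetical layer size, chosen only for illustration.
d_in = d_out = 1024

print(lora_scale(rank=128, alpha=128))                            # 1.0
print(lora_extra_params(d_in, d_out, rank=128))                   # 262144
print(lora_extra_params(d_in, d_out, rank=128) / (d_in * d_out))  # 0.25
```

Under these assumptions, setting `rank` equal to `alpha` leaves the update unscaled, and a rank of 128 on a 1024-wide layer trains a quarter as many parameters as the full weight matrix.
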
To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`. You *should* be able to have the LoRA weights loaded from a training checkpoint automagically for inferencing, but export them just to be safe.

### Plotting Metrics

Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"`

You can specify what X and Y labels you want to plot against by passing `--xs tokens_processed --ys loss stats.acc`

### Notices

#### Training Under Windows

As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.

Creature comforts like `float16`, `amp`, and multi-GPU training *should* work under the `local` backend, but extensive testing still needs to be done to ensure it all functions.

#### Backend Architectures

As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:

* `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
  + I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).
* `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
  - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
  - Its implementation for MoE can also be utilized.
* `retnet-hf`: using [syncdoth/RetNet](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
  - has an inference penalty, and MoE is not implemented.
* `mamba`: using [state-spaces/mamba](https://github.com/state-spaces/mamba) (needs to mature)
  - ***really hard*** to have a unified AR and NAR model
  - inference penalty makes it a really hard sell, despite the loss already being a low 3 after a short amount of samples processed

For audio backends:

* [`encodec`](https://github.com/facebookresearch/encodec): a tried-and-tested EnCodec to encode/decode audio.
* [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder.
  - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
  - models at 24KHz + 8kbps will NOT converge in any manner.
  - at 44KHz + 8kbps, the codec's "language" seems harder to model, and the NAR side of the model suffers greatly.

`llama`-based models also support different attention backends:

* `math`: torch's SDPA's `math` implementation
* `mem_efficient`: torch's SDPA's memory efficient (`xformers` adjacent) implementation
* `flash`: torch's SDPA's flash attention implementation
* `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
* `auto`: determine the best fit from the above
* `sdpa`: integrated `LlamaSdpaAttention` attention model
* `flash_attention_2`: integrated `LlamaFlashAttention2` attention model

The wide support for various backends exists solely while I try to figure out which is the "best" for a core foundation model.

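
For reference, the attention backends listed above all compute the same quantity, `softmax(Q K^T / sqrt(d)) V`, up to numerical precision; they differ in how the work is scheduled and how much memory it needs. Below is a minimal pure-Python sketch of that computation for a single head. The function name `sdpa_reference` is hypothetical, and the code is meant only as an illustration of the math, not as this repo's implementation or as something fast enough to use.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q, k, v, causal=False):
    """Scaled dot-product attention for one head.

    q: list of Tq query vectors (each of length d)
    k: list of Tk key vectors (each of length d)
    v: list of Tk value vectors
    causal: if True, query i may only attend to keys 0..i
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if causal and j > i:
                scores.append(float("-inf"))  # masked out, weight becomes 0
            else:
                scores.append(scale * sum(a * b for a, b in zip(qi, kj)))
        weights = softmax(scores)
        out.append([
            sum(w * vj[c] for w, vj in zip(weights, v))
            for c in range(len(v[0]))
        ])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
# With causal masking, the first position can only see itself.
print(sdpa_reference(q, k, v, causal=True)[0])  # [1.0, 2.0]
```

The nested loops materialize every score explicitly, which is the kind of quadratic cost in sequence length that the fused kernels are designed to reduce.
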
## Export

To export the models, run: `python -m vall_e.export --yaml=./training/config.yaml`.

This will export the latest checkpoints, for example, under `./training/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.

## Synthesis

To synthesize speech: `python -m vall_e