diff --git a/README.md b/README.md
index eeaed13..0063ddf 100755
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ Training is very dependent on:
 
 ### Pre-Processed Dataset
 
-A "libre" dataset can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
+A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
 
 A script to setup a proper environment and train can be invoked with `./scripts/setup-training.sh`
 
@@ -52,6 +52,8 @@ A script to setup a proper environment and train can be invoked with `./scripts/
 
 > **Note** Preparing a dataset is a bit messy.
 
+If you already have a dataset, for example your own large corpus or one for finetuning, you can use it instead.
+
 0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
   + At the moment only WhisperX is utilized. Using other variants like `faster-whisper` is an exercise left to the user at the moment.
   + It's recommended to use a dedicated virtualenv specifically for transcribing, as WhisperX will break a few dependencies.
@@ -84,7 +86,7 @@ Two dataset formats are supported:
   - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
   - be sure to also define `use_hdf5` in your config YAML.
 
-### Initializing Training
+### Training
 
 For single GPUs, simply running `python3 -m vall_e.train yaml="./training/config.yaml`.
 
@@ -114,6 +116,7 @@ Keep in mind that creature comforts like distributed training or `float16` train
 Unfortunately, efforts to train a *good* foundational model seems entirely predicated on a good dataset. My dataset might be too fouled with:
 * too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
+  + It might help to, instead, initially train with smaller utterances, train for two epochs, then increase the sample length.
 * too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
 * a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
 
@@ -129,7 +132,7 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
 * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
 * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
 * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
-  - Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
+  - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
 
 If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech with it, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains.
 
@@ -173,6 +176,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * clean up the README, and document, document, document onto the wiki.
 * extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
   - training additional tasks needs the SpeechX implementation to be reworked.
+  - this requires a good foundational model before extending it to transfer tasks.
 * improve throughput (despite peaking at 120it/s):
   - properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained).
   - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
diff --git a/setup.py b/setup.py
index b5ece1a..98c9d50 100755
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,7 @@ setup(
 		"encodec>=0.1.1",
 		"phonemizer>=2.1.0",
 		"matplotlib>=3.6.0",
-		"numpy==1.23.0",
+		"numpy",
 		"omegaconf==2.0.6",
 		"tqdm>=4.64.1",
 		"humanize>=4.4.0",
@@ -60,6 +60,8 @@ setup(
 		"h5py",
 		"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
 		"prodigyopt @ git+https://github.com/konstmish/prodigy",
+
+		"descript-audio-codec",
 	],
 	url="https://git.ecker.tech/mrq/vall-e",
 )
diff --git a/vall_e/config.py b/vall_e/config.py
index 5900b54..3a06730 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -512,10 +512,10 @@ class Inference:
 	use_encodec: bool = True
 	use_dac: bool = True
 
+	# shit that doesn't work
 	recurrent_chunk_size: int = 0
 	recurrent_forward: bool = False
 
-
 	@cached_property
 	def dtype(self):
 		if self.weight_dtype == "float16":
@@ -562,7 +562,7 @@ class Config(_Config):
 	tokenizer: str = "./tokenizer.json"
 
 	sample_rate: int = 24_000
-	variable_sample_rate: bool = True
+	variable_sample_rate: bool = True # for DAC, this will override the model automatically resampling to 44KHz.
 
 	@property
 	def distributed(self):
diff --git a/vall_e/data.py b/vall_e/data.py
index 16c9f13..aa5f2a4 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -370,9 +370,8 @@ class Dataset(_Dataset):
 
 		# shuffle it up a bit
 		prom_length = 0
-		if cfg.experimental:
-			trim_length = random.randint(cfg.dataset.frames_per_second * 3, cfg.dataset.frames_per_second * 6) # [3 seconds, 6 seconds]
-			#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
+		if cfg.experimental and False:
+			trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
 		else:
 			trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index b931e17..3164b36 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -125,7 +125,7 @@ class AR_NAR(Base):
 
 		if n_levels == self.n_resp_levels:
 			# might be better to have this decided on the dataloader level
-			if cfg.experimental and False:
+			if cfg.experimental:
 				# makes higher levels less likely
 				def generate( lo=0, hi=8 ):
 					index = lo
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 28f2359..c764557 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -197,7 +197,7 @@ try:
 except Exception as e:
 	print("Error creating `LLamaXformersAttention`:", e)
 
-def replace_attention( model, impl, verbose=Valse ):
+def replace_attention( model, impl, verbose=False ):
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
 	attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
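
For context on the `vall_e/data.py` hunk above: the experimental branch is rewritten to draw the prompt trim length from a normal distribution centered on five seconds (though it stays gated off behind `and False` for now), while the default branch takes the configured prompt duration and jitters it by up to one second in either direction. Below is a minimal, self-contained sketch of that sampling logic; `sample_trim_length` and its defaults (75 frames per second, as with EnCodec at 24kHz, and a 3-second prompt) are hypothetical stand-ins for the `cfg.dataset` fields, not part of the repo.

```python
import random

import numpy as np

def sample_trim_length( frames_per_second=75, prompt_duration=3.0, experimental=False ):
	"""Pick how many codec frames of the prompt to keep, mirroring the data.py hunk above."""
	if experimental:
		# roughly N(5s, 1.25s) worth of frames, floored at 2 frames
		return max(2, int(np.random.normal(loc=5, scale=1.25) * frames_per_second))
	# default: the configured prompt duration, jittered by up to +/- 1 second
	return int(prompt_duration * frames_per_second) + random.randint(-frames_per_second, frames_per_second)

if __name__ == "__main__":
	# e.g. roughly 225 frames +/- 75 for a 3 second prompt at 75 frames per second
	print( sample_trim_length() )
```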