small cleanups

This commit is contained in:
mrq 2024-05-04 22:37:22 -05:00
parent 8aa1b2dabf
commit 33b7f81b94
6 changed files with 16 additions and 11 deletions

View File

@ -44,7 +44,7 @@ Training is very dependent on:
### Pre-Processed Dataset ### Pre-Processed Dataset
A "libre" dataset can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`. A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
A script to setup a proper environment and train can be invoked with `./scripts/setup-training.sh` A script to setup a proper environment and train can be invoked with `./scripts/setup-training.sh`
@ -52,6 +52,8 @@ A script to setup a proper environment and train can be invoked with `./scripts/
> **Note** Preparing a dataset is a bit messy. > **Note** Preparing a dataset is a bit messy.
If you already have a dataset you want, for example your own large corpus, or for finetuning, you can use your own dataset instead.
0. Set up a `venv` with `https://github.com/m-bain/whisperX/`. 0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
+ At the moment only WhisperX is utilized. Using other variants like `faster-whisper` is an exercise left to the user at the moment. + At the moment only WhisperX is utilized. Using other variants like `faster-whisper` is an exercise left to the user at the moment.
+ It's recommended to use a dedicated virtualenv specifically for transcribing, as WhisperX will break a few dependencies. + It's recommended to use a dedicated virtualenv specifically for transcribing, as WhisperX will break a few dependencies.
@ -84,7 +86,7 @@ Two dataset formats are supported:
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths) - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
- be sure to also define `use_hdf5` in your config YAML. - be sure to also define `use_hdf5` in your config YAML.
### Initializing Training ### Training
For single GPUs, simply running `python3 -m vall_e.train yaml="./training/config.yaml`. For single GPUs, simply running `python3 -m vall_e.train yaml="./training/config.yaml`.
@ -114,6 +116,7 @@ Keep in mind that creature comforts like distributed training or `float16` train
Unfortunately, efforts to train a *good* foundational model seems entirely predicated on a good dataset. My dataset might be too fouled with: Unfortunately, efforts to train a *good* foundational model seems entirely predicated on a good dataset. My dataset might be too fouled with:
* too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long. * too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
+ It might help to, instead, initially train with smaller utterances, train for two epochs, then increase the each sample length.
* too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances. * too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
* a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping. * a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
@ -129,7 +132,7 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements. * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation. * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer. * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
- Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation. - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech with it, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains. If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech with it, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains.
@ -173,6 +176,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
* clean up the README, and document, document, document onto the wiki. * clean up the README, and document, document, document onto the wiki.
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)). * extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- training additional tasks needs the SpeechX implementation to be reworked. - training additional tasks needs the SpeechX implementation to be reworked.
- this requires a good foundational model before extending it to transfer tasks onto.
* improve throughput (despite peaking at 120it/s): * improve throughput (despite peaking at 120it/s):
- properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained). - properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained).
- utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens

View File

@ -44,7 +44,7 @@ setup(
"encodec>=0.1.1", "encodec>=0.1.1",
"phonemizer>=2.1.0", "phonemizer>=2.1.0",
"matplotlib>=3.6.0", "matplotlib>=3.6.0",
"numpy==1.23.0", "numpy",
"omegaconf==2.0.6", "omegaconf==2.0.6",
"tqdm>=4.64.1", "tqdm>=4.64.1",
"humanize>=4.4.0", "humanize>=4.4.0",
@ -60,6 +60,8 @@ setup(
"h5py", "h5py",
"torchscale @ git+https://git.ecker.tech/mrq/torchscale", "torchscale @ git+https://git.ecker.tech/mrq/torchscale",
"prodigyopt @ git+https://github.com/konstmish/prodigy", "prodigyopt @ git+https://github.com/konstmish/prodigy",
"descript-audio-codec",
], ],
url="https://git.ecker.tech/mrq/vall-e", url="https://git.ecker.tech/mrq/vall-e",
) )

View File

@ -512,10 +512,10 @@ class Inference:
use_encodec: bool = True use_encodec: bool = True
use_dac: bool = True use_dac: bool = True
# shit that doesn't work
recurrent_chunk_size: int = 0 recurrent_chunk_size: int = 0
recurrent_forward: bool = False recurrent_forward: bool = False
@cached_property @cached_property
def dtype(self): def dtype(self):
if self.weight_dtype == "float16": if self.weight_dtype == "float16":
@ -562,7 +562,7 @@ class Config(_Config):
tokenizer: str = "./tokenizer.json" tokenizer: str = "./tokenizer.json"
sample_rate: int = 24_000 sample_rate: int = 24_000
variable_sample_rate: bool = True variable_sample_rate: bool = True # for DAC, this will override the model automatically resampling to 44KHz.
@property @property
def distributed(self): def distributed(self):

View File

@ -370,9 +370,8 @@ class Dataset(_Dataset):
# shuffle it up a bit # shuffle it up a bit
prom_length = 0 prom_length = 0
if cfg.experimental: if cfg.experimental and False:
trim_length = random.randint(cfg.dataset.frames_per_second * 3, cfg.dataset.frames_per_second * 6) # [3 seconds, 6 seconds] trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
else: else:
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second) trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)

View File

@ -125,7 +125,7 @@ class AR_NAR(Base):
if n_levels == self.n_resp_levels: if n_levels == self.n_resp_levels:
# might be better to have this decided on the dataloader level # might be better to have this decided on the dataloader level
if cfg.experimental and False: if cfg.experimental:
# makes higher levels less likely # makes higher levels less likely
def generate( lo=0, hi=8 ): def generate( lo=0, hi=8 ):
index = lo index = lo

View File

@ -197,7 +197,7 @@ try:
except Exception as e: except Exception as e:
print("Error creating `LLamaXformersAttention`:", e) print("Error creating `LLamaXformersAttention`:", e)
def replace_attention( model, impl, verbose=Valse ): def replace_attention( model, impl, verbose=False ):
device = next(model.parameters()).device device = next(model.parameters()).device
dtype = next(model.parameters()).dtype dtype = next(model.parameters()).dtype
attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)] attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]