small cleanups

This commit is contained in:
mrq 2024-05-04 22:37:22 -05:00
parent 8aa1b2dabf
commit 33b7f81b94
6 changed files with 16 additions and 11 deletions

View File

@@ -44,7 +44,7 @@ Training is very dependent on:
### Pre-Processed Dataset
A "libre" dataset can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
A script to set up a proper environment and train can be invoked with `./scripts/setup-training.sh`.
@@ -52,6 +52,8 @@ A script to setup a proper environment and train can be invoked with `./scripts/
> **Note** Preparing a dataset is a bit messy.
If you already have a dataset you want to use, for example your own large corpus or one for finetuning, you can use it instead.
0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
+ At the moment, only WhisperX is utilized; using other variants like `faster-whisper` is left as an exercise to the user.
+ It's recommended to use a dedicated virtualenv specifically for transcribing, as WhisperX will break a few dependencies.
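As a rough illustration of the transcription step, a minimal sketch using WhisperX's Python API; the file path, model size, and batch size below are placeholders, not what the dataset scripts actually use:

```python
import whisperx

device = "cuda"

# model size, audio path, and batch size are assumptions for illustration
model = whisperx.load_model("large-v2", device, compute_type="float16")
audio = whisperx.load_audio("./voices/speaker/utterance.wav")
result = model.transcribe(audio, batch_size=16)

# align for per-segment timestamps, handy when trimming utterances later
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device)
print(result["segments"])
```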
@@ -84,7 +86,7 @@ Two dataset formats are supported:
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the generated symbol map and text/audio lengths)
- be sure to also define `use_hdf5` in your config YAML.
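If you want to sanity-check the packed file, a minimal sketch with `h5py` (the path and layout here are assumptions; actual group and attribute names depend on how the dataset was prepared):

```python
import h5py

# hypothetical path; point this at the HDF5 file your config YAML references
with h5py.File("./training/data.h5", "r") as f:
    print(len(f.keys()), "top-level entries")
    print(dict(f.attrs))  # file-level metadata, if any was attached
    f.visititems(lambda name, obj: print(name, obj))  # walk every stored group/dataset
```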
### Initializing Training
### Training
For single GPUs, simply run `python3 -m vall_e.train yaml="./training/config.yaml"`.
@@ -114,6 +116,7 @@ Keep in mind that creature comforts like distributed training or `float16` train
Unfortunately, efforts to train a *good* foundational model seem entirely predicated on a good dataset. My dataset might be too fouled with:
* too-short utterances: trying to extrapolate to longer contexts seems to utterly fall apart when the `text` is too long.
+ It might help to instead initially train with shorter utterances for two epochs, then increase each sample's length.
* too tightly trimmed utterances: having little to no space at the start and end might harm associating the `<s>` and `</s>` tokens with empty utterances.
* a poor phoneme mapping: I naively crafted my own phoneme map, where a HuggingFace tokenizer might supply a better token mapping (a rough sketch of training one follows below).
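As a rough sketch of the tokenizer point above, training a small BPE tokenizer over phonemized transcripts with HuggingFace's `tokenizers` library; the corpus path and special tokens are assumptions:

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# hypothetical corpus: one phonemized transcript per line
files = ["./training/phonemes.txt"]

tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
trainer = trainers.BpeTrainer(special_tokens=["<unk>", "<s>", "</s>"])
tokenizer.train(files, trainer)
tokenizer.save("./tokenizer.json")  # e.g. the config's default `tokenizer` path
```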
@@ -129,7 +132,7 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
- Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
- Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech built around them, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains.
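For reference, the `llama` backend builds on HF transformers' LLaMA modules; a minimal sketch of instantiating that backbone directly, with placeholder dimensions that are not the project's actual configuration:

```python
from transformers import LlamaConfig, LlamaModel

# placeholder dimensions purely for illustration
config = LlamaConfig(
    vocab_size=256,
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=12,
    num_attention_heads=16,
)
model = LlamaModel(config)  # attention-based transformer with RoPE built in
```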
@@ -173,6 +176,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
* clean up the README, and document, document, document onto the wiki.
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- training additional tasks requires reworking the SpeechX implementation.
- this requires a good foundational model first, before transferring additional tasks onto it.
* improve throughput (despite peaking at 120it/s):
- properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained).
- utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens

View File

@@ -44,7 +44,7 @@ setup(
"encodec>=0.1.1",
"phonemizer>=2.1.0",
"matplotlib>=3.6.0",
"numpy==1.23.0",
"numpy",
"omegaconf==2.0.6",
"tqdm>=4.64.1",
"humanize>=4.4.0",
@@ -60,6 +60,8 @@ setup(
"h5py",
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
"prodigyopt @ git+https://github.com/konstmish/prodigy",
"descript-audio-codec",
],
url="https://git.ecker.tech/mrq/vall-e",
)

View File

@@ -512,10 +512,10 @@ class Inference:
use_encodec: bool = True
use_dac: bool = True
# shit that doesn't work
recurrent_chunk_size: int = 0
recurrent_forward: bool = False
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
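The body of the `dtype` property is truncated in this hunk; a plausible sketch of the string-to-`torch.dtype` mapping it performs (an assumption, not the actual implementation):

```python
import torch

def dtype_from_string(weight_dtype: str) -> torch.dtype:
    # assumed mapping; fall back to float32 for anything unrecognized
    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(weight_dtype, torch.float32)
```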
@@ -562,7 +562,7 @@ class Config(_Config):
tokenizer: str = "./tokenizer.json"
sample_rate: int = 24_000
variable_sample_rate: bool = True
variable_sample_rate: bool = True # for DAC, this will override the model's automatic resampling to 44KHz.
@property
def distributed(self):

View File

@@ -370,9 +370,8 @@ class Dataset(_Dataset):
# shuffle it up a bit
prom_length = 0
if cfg.experimental:
trim_length = random.randint(cfg.dataset.frames_per_second * 3, cfg.dataset.frames_per_second * 6) # [3 seconds, 6 seconds]
#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
if cfg.experimental and False:
trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
else:
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)
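For reference, a self-contained sketch of the two trim-length strategies above: a clipped normal around five seconds versus a uniform jitter around `prompt_duration`. The `frames_per_second` and `prompt_duration` values are assumed for illustration:

```python
import random
import numpy as np

frames_per_second = 75  # assumed; EnCodec at 24kHz emits 75 frames per second
prompt_duration = 3.0   # assumed, in seconds

# experimental path: sample a length around 5 seconds from a clipped normal
trim_normal = max(2, int(np.random.normal(loc=5, scale=1.25) * frames_per_second))

# default path: prompt_duration plus up to ±1 second of uniform jitter
trim_uniform = int(prompt_duration * frames_per_second) + random.randint(-frames_per_second, frames_per_second)
```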

View File

@@ -125,7 +125,7 @@ class AR_NAR(Base):
if n_levels == self.n_resp_levels:
# might be better to have this decided on the dataloader level
if cfg.experimental and False:
if cfg.experimental:
# makes higher levels less likely
def generate( lo=0, hi=8 ):
index = lo
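The body of `generate` is truncated in this hunk; a plausible sketch of a sampler that makes higher levels less likely (an assumption about the approach, not the actual implementation):

```python
import random

def generate(lo=0, hi=8, p=0.5):
    # walk upward from `lo`, continuing with probability p at each step,
    # so each successive level is geometrically less likely to be picked
    index = lo
    while index < hi - 1 and random.random() < p:
        index += 1
    return index
```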

View File

@@ -197,7 +197,7 @@ try:
except Exception as e:
print("Error creating `LLamaXformersAttention`:", e)
def replace_attention( model, impl, verbose=Valse ):
def replace_attention( model, impl, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
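The rest of `replace_attention` is truncated in this hunk; a generic sketch of the module-swapping pattern it relies on, walking `named_modules` and replacing matching children via `setattr` (the helper below is hypothetical, not the project's code):

```python
import torch.nn as nn

def replace_modules(model: nn.Module, target: type, make_replacement):
    # find every child module of `target` type and swap it in-place
    for parent_name, parent in list(model.named_modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, target):
                setattr(parent, child_name, make_replacement(child))
    return model
```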