small cleanups

parent 8aa1b2dabf
commit 33b7f81b94

README.md: 10 lines changed

@@ -44,7 +44,7 @@ Training is very dependent on:

 ### Pre-Processed Dataset

-A "libre" dataset can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
+A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.

 A script to setup a proper environment and train can be invoked with `./scripts/setup-training.sh`
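
For reference, a minimal sketch of fetching that archive programmatically; it is not part of the repo, only the repo id and filename come from the link above, and the extraction path is arbitrary:

```python
# Sketch: grab the pre-processed dataset archive from the HF repo linked above.
# Only the repo id and filename are taken from the README; the destination is a guess.
import tarfile
from huggingface_hub import hf_hub_download

archive = hf_hub_download(repo_id="ecker/vall-e", filename="data.tar.gz")
with tarfile.open(archive) as tar:
    tar.extractall("./training/data")  # point your training config at wherever this lands
```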

@@ -52,6 +52,8 @@ A script to setup a proper environment and train can be invoked with `./scripts/

 > **Note** Preparing a dataset is a bit messy.

+If you already have a dataset you want, for example your own large corpus, or for finetuning, you can use your own dataset instead.

 0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
 + At the moment only WhisperX is utilized. Using other variants like `faster-whisper` is an exercise left to the user at the moment.
 + It's recommended to use a dedicated virtualenv specifically for transcribing, as WhisperX will break a few dependencies.
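
As a point of reference for step 0, the standard WhisperX usage from its own documentation, run inside that dedicated venv; model size, compute type, and file path are placeholders:

```python
# Standard WhisperX flow (per the WhisperX README); values here are placeholders.
import whisperx

device = "cuda"
model = whisperx.load_model("large-v2", device, compute_type="float16")
audio = whisperx.load_audio("./voices/speaker/utterance.wav")
result = model.transcribe(audio, batch_size=16)

# word/segment alignment, which is where timed transcriptions for slicing come from
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], align_model, metadata, audio, device)
```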

@@ -84,7 +86,7 @@ Two dataset formats are supported:

 - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
 - be sure to also define `use_hdf5` in your config YAML.

-### Initializing Training
+### Training

 For single GPUs, simply running `python3 -m vall_e.train yaml="./training/config.yaml"`.
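
Purely as an illustration of the HDF5 format described above, a sketch of writing one utterance plus metadata with `h5py`; the actual group and dataset names the repo uses are not shown in this diff and may differ:

```python
# Illustrative layout only: one HDF5 file holding token/code arrays plus lengths as attributes.
# The real keys come from vall_e's dataset tooling, not from this sketch.
import h5py
import numpy as np

with h5py.File("./training/data.h5", "w") as f:
    utt = f.create_group("speaker_0/utterance_0000")
    utt.create_dataset("text", data=np.array([1, 42, 17, 2], dtype=np.int32))  # phoneme ids
    utt.create_dataset("audio", data=np.zeros((75 * 4, 8), dtype=np.int16))    # codec codes: frames x levels
    utt.attrs["text_length"] = 4
    utt.attrs["audio_length"] = 75 * 4
```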

@@ -114,6 +116,7 @@ Keep in mind that creature comforts like distributed training or `float16` train

 Unfortunately, efforts to train a *good* foundational model seems entirely predicated on a good dataset. My dataset might be too fouled with:
 * too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
++ It might help to, instead, initially train with smaller utterances, train for two epochs, then increase each sample length.
 * too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
 * a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
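
The sub-bullet added above suggests a length-based curriculum: start on short utterances, then raise the cap after a couple of epochs. A hypothetical sketch of such a filter (none of these names exist in the repo, and the caps are made up):

```python
# Hypothetical curriculum filter; duration caps and step sizes are illustrative only.
def max_duration_for_epoch(epoch: int, base_seconds: float = 6.0, step_seconds: float = 6.0) -> float:
    # epochs 0-1 stay at the base cap, later epochs admit progressively longer samples
    return base_seconds + step_seconds * max(0, epoch - 1)

def filter_by_duration(samples: list[dict], epoch: int) -> list[dict]:
    cap = max_duration_for_epoch(epoch)
    return [s for s in samples if s["duration"] <= cap]
```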

@@ -129,7 +132,7 @@ As the core of VALL-E makes use of a language model, various LLM architectures c

 * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
 * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
 * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
-- Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
+- Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.

 If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech with it, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains.
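
For intuition on the `bitnet` bullet: using "BitNet's linear implementation" generally amounts to replacing every `nn.Linear` with the BitNet-style layer. A generic sketch, where `bit_linear_cls` stands in for whatever the linked BitNet repo provides and its constructor signature is assumed:

```python
# Generic module swap; `bit_linear_cls` is a stand-in for the BitNet linear layer,
# and its constructor arguments are an assumption.
import torch.nn as nn

def replace_linears(module: nn.Module, bit_linear_cls) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, bit_linear_cls(child.in_features, child.out_features))
        else:
            replace_linears(child, bit_linear_cls)
```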

@@ -173,6 +176,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin

 * clean up the README, and document, document, document onto the wiki.
 * extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
 - training additional tasks needs the SpeechX implementation to be reworked.
+- this requires a good foundational model before extending it to transfer tasks onto.
 * improve throughput (despite peaking at 120it/s):
 - properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained).
 - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
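
To make the Medusa-style item concrete: the idea is extra decoding heads that each predict a token further ahead from the same hidden state, so several AR tokens can be proposed per forward pass. A purely illustrative sketch, not the repo's code:

```python
# Illustrative Medusa-style heads; names, shapes, and head count are assumptions.
import torch
import torch.nn as nn

class SpeculativeHeads(nn.Module):
    def __init__(self, d_model: int, vocab_size: int, n_ahead: int = 3):
        super().__init__()
        # one extra head per lookahead position: N+1, N+2, N+3
        self.heads = nn.ModuleList(nn.Linear(d_model, vocab_size) for _ in range(n_ahead))

    def forward(self, last_hidden: torch.Tensor) -> list[torch.Tensor]:
        # last_hidden: [batch, d_model] taken from the final position of the backbone
        return [head(last_hidden) for head in self.heads]
```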

setup.py: 4 lines changed

@@ -44,7 +44,7 @@ setup(

 "encodec>=0.1.1",
 "phonemizer>=2.1.0",
 "matplotlib>=3.6.0",
-"numpy==1.23.0",
+"numpy",
 "omegaconf==2.0.6",
 "tqdm>=4.64.1",
 "humanize>=4.4.0",

@@ -60,6 +60,8 @@ setup(

 "h5py",
 "torchscale @ git+https://git.ecker.tech/mrq/torchscale",
 "prodigyopt @ git+https://github.com/konstmish/prodigy",
+
+"descript-audio-codec",
 ],
 url="https://git.ecker.tech/mrq/vall-e",
 )

@@ -512,10 +512,10 @@ class Inference:

 use_encodec: bool = True
 use_dac: bool = True

+# shit that doesn't work
 recurrent_chunk_size: int = 0
 recurrent_forward: bool = False

 @cached_property
 def dtype(self):
 if self.weight_dtype == "float16":

@@ -562,7 +562,7 @@ class Config(_Config):

 tokenizer: str = "./tokenizer.json"

 sample_rate: int = 24_000
-variable_sample_rate: bool = True
+variable_sample_rate: bool = True # for DAC, this will override the model automatically resampling to 44KHz.

 @property
 def distributed(self):
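
One reading of that new comment: DAC natively runs at 44KHz, and with `variable_sample_rate` enabled the audio presumably stays at the configured `sample_rate` instead of being bumped up. A generic torchaudio sketch of that kind of conditional resample, not the repo's actual handling:

```python
# Generic conditional resample; the flag semantics are inferred from the comment above
# and may not match the repo's behaviour exactly.
import torchaudio

def maybe_resample(wav, orig_sr: int, cfg_sr: int = 24_000, variable_sample_rate: bool = True):
    target_sr = cfg_sr if variable_sample_rate else 44_100  # DAC's native rate
    if orig_sr == target_sr:
        return wav, target_sr
    return torchaudio.functional.resample(wav, orig_sr, target_sr), target_sr
```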

@@ -370,9 +370,8 @@ class Dataset(_Dataset):

 # shuffle it up a bit
 prom_length = 0
-if cfg.experimental:
-trim_length = random.randint(cfg.dataset.frames_per_second * 3, cfg.dataset.frames_per_second * 6) # [3 seconds, 6 seconds]
-#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
+if cfg.experimental and False:
+trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
 else:
 trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)
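
The hunk above swaps the experimental prompt trim from a uniform 3-6 second draw to a normal distribution centered on 5 seconds, floored at 2 frames. The same idea as a standalone sketch; the frames-per-second value is just an example (EnCodec at 24KHz yields 75 frames per second):

```python
# Standalone illustration of the two trim-length strategies from the hunk above.
import random
import numpy as np

FRAMES_PER_SECOND = 75  # example value for EnCodec at 24KHz

def trim_length_uniform() -> int:
    # previous behaviour: anywhere from 3 to 6 seconds, uniformly
    return random.randint(FRAMES_PER_SECOND * 3, FRAMES_PER_SECOND * 6)

def trim_length_normal() -> int:
    # new behaviour: ~5 seconds on average with a 1.25 second sigma, floored at 2 frames
    return max(2, int(np.random.normal(loc=5, scale=1.25) * FRAMES_PER_SECOND))
```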

@@ -125,7 +125,7 @@ class AR_NAR(Base):

 if n_levels == self.n_resp_levels:
 # might be better to have this decided on the dataloader level

-if cfg.experimental and False:
+if cfg.experimental:
 # makes higher levels less likely
 def generate( lo=0, hi=8 ):
 index = lo
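
The block being re-enabled here samples which RVQ level a batch trains on, biased toward lower levels; the body of `generate` is cut off in this hunk. One generic way to get that kind of skew, purely as an illustration:

```python
# Illustrative only: a geometric-style walk that makes higher quantizer levels
# progressively less likely. The repo's actual `generate` may weight levels differently.
import random

def sample_level(lo: int = 0, hi: int = 8, p_step: float = 0.5) -> int:
    index = lo
    while index < hi - 1 and random.random() < p_step:
        index += 1
    return index
```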

@@ -197,7 +197,7 @@ try:

 except Exception as e:
 print("Error creating `LLamaXformersAttention`:", e)

-def replace_attention( model, impl, verbose=Valse ):
+def replace_attention( model, impl, verbose=False ):
 device = next(model.parameters()).device
 dtype = next(model.parameters()).dtype
 attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
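
For context on the last line of the hunk: `replace_attention` collects the dotted module paths of every `LlamaAttention`, and the usual way to finish the swap is to walk each path to its parent module and `setattr` the replacement there. A generic sketch of that step; the construction of the replacement module is elided and not part of this diff:

```python
# Generic dotted-path module swap; not the repo's code.
def swap_module(model, dotted_name: str, new_module) -> None:
    *parents, leaf = dotted_name.split('.')
    parent = model
    for part in parents:
        parent = getattr(parent, part)  # e.g. walk "layers.3.self_attn" down to "layers.3"
    setattr(parent, leaf, new_module)
```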