small cleanups
This commit is contained in:
parent
8aa1b2dabf
commit
33b7f81b94
10
README.md
10
README.md
|
@ -44,7 +44,7 @@ Training is very dependent on:
|
|||
|
||||
### Pre-Processed Dataset
|
||||
|
||||
A "libre" dataset can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
|
||||
A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.
|
||||
|
||||
A script to setup a proper environment and train can be invoked with `./scripts/setup-training.sh`
|
||||
|
||||
|
@ -52,6 +52,8 @@ A script to setup a proper environment and train can be invoked with `./scripts/
|
|||
|
||||
> **Note** Preparing a dataset is a bit messy.
|
||||
|
||||
If you already have a dataset you want, for example your own large corpus, or for finetuning, you can use your own dataset instead.
|
||||
|
||||
0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.
|
||||
+ At the moment only WhisperX is utilized. Using other variants like `faster-whisper` is an exercise left to the user at the moment.
|
||||
+ It's recommended to use a dedicated virtualenv specifically for transcribing, as WhisperX will break a few dependencies.
|
||||
|
@ -84,7 +86,7 @@ Two dataset formats are supported:
|
|||
- this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
|
||||
- be sure to also define `use_hdf5` in your config YAML.
|
||||
|
||||
### Initializing Training
|
||||
### Training
|
||||
|
||||
For single GPUs, simply running `python3 -m vall_e.train yaml="./training/config.yaml`.
|
||||
|
||||
|
@ -114,6 +116,7 @@ Keep in mind that creature comforts like distributed training or `float16` train
|
|||
|
||||
Unfortunately, efforts to train a *good* foundational model seems entirely predicated on a good dataset. My dataset might be too fouled with:
|
||||
* too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
|
||||
+ It might help to, instead, initially train with smaller utterances, train for two epochs, then increase the each sample length.
|
||||
* too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
|
||||
* a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
|
||||
|
||||
|
@ -129,7 +132,7 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
|
|||
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
|
||||
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
|
||||
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
|
||||
- Setting `bitsandbytes.bitnet=True` will make use of BitNet's linear implementation.
|
||||
- Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
|
||||
|
||||
If you're training a true foundational model, consider which backend you want to use the most. `llama` backends can benefit from all the additional tech with it, while exotic ones like `retnet` or `bitnet` can't at the moment, but may leverage experimental gains.
|
||||
|
||||
|
@ -173,6 +176,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
|
|||
* clean up the README, and document, document, document onto the wiki.
|
||||
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
|
||||
- training additional tasks needs the SpeechX implementation to be reworked.
|
||||
- this requires a good foundational model before extending it to transfer tasks onto.
|
||||
* improve throughput (despite peaking at 120it/s):
|
||||
- properly utilize RetNet's recurrent forward / chunkwise forward passes (does not seem to want to work no matter how the model is trained).
|
||||
- utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
|
||||
|
|
4
setup.py
4
setup.py
|
@ -44,7 +44,7 @@ setup(
|
|||
"encodec>=0.1.1",
|
||||
"phonemizer>=2.1.0",
|
||||
"matplotlib>=3.6.0",
|
||||
"numpy==1.23.0",
|
||||
"numpy",
|
||||
"omegaconf==2.0.6",
|
||||
"tqdm>=4.64.1",
|
||||
"humanize>=4.4.0",
|
||||
|
@ -60,6 +60,8 @@ setup(
|
|||
"h5py",
|
||||
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
|
||||
"prodigyopt @ git+https://github.com/konstmish/prodigy",
|
||||
|
||||
"descript-audio-codec",
|
||||
],
|
||||
url="https://git.ecker.tech/mrq/vall-e",
|
||||
)
|
||||
|
|
|
@ -512,10 +512,10 @@ class Inference:
|
|||
use_encodec: bool = True
|
||||
use_dac: bool = True
|
||||
|
||||
# shit that doesn't work
|
||||
recurrent_chunk_size: int = 0
|
||||
recurrent_forward: bool = False
|
||||
|
||||
|
||||
@cached_property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
|
@ -562,7 +562,7 @@ class Config(_Config):
|
|||
tokenizer: str = "./tokenizer.json"
|
||||
|
||||
sample_rate: int = 24_000
|
||||
variable_sample_rate: bool = True
|
||||
variable_sample_rate: bool = True # for DAC, this will override the model automatically resampling to 44KHz.
|
||||
|
||||
@property
|
||||
def distributed(self):
|
||||
|
|
|
@ -370,9 +370,8 @@ class Dataset(_Dataset):
|
|||
|
||||
# shuffle it up a bit
|
||||
prom_length = 0
|
||||
if cfg.experimental:
|
||||
trim_length = random.randint(cfg.dataset.frames_per_second * 3, cfg.dataset.frames_per_second * 6) # [3 seconds, 6 seconds]
|
||||
#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
|
||||
if cfg.experimental and False:
|
||||
trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * cfg.dataset.frames_per_second))
|
||||
else:
|
||||
trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second) + random.randint(-cfg.dataset.frames_per_second, cfg.dataset.frames_per_second)
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ class AR_NAR(Base):
|
|||
if n_levels == self.n_resp_levels:
|
||||
# might be better to have this decided on the dataloader level
|
||||
|
||||
if cfg.experimental and False:
|
||||
if cfg.experimental:
|
||||
# makes higher levels less likely
|
||||
def generate( lo=0, hi=8 ):
|
||||
index = lo
|
||||
|
|
|
@ -197,7 +197,7 @@ try:
|
|||
except Exception as e:
|
||||
print("Error creating `LLamaXformersAttention`:", e)
|
||||
|
||||
def replace_attention( model, impl, verbose=Valse ):
|
||||
def replace_attention( model, impl, verbose=False ):
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
|
||||
|
|
Loading…
Reference in New Issue
Block a user