imagine my disappointment when the epoch finished just for it to throw an exception

This commit is contained in:
mrq 2024-12-16 18:28:01 -06:00
parent 4a65ac9eb7
commit 8515038968
7 changed files with 113 additions and 73 deletions

View File

@ -47,24 +47,16 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
* [x] train and release a ***good*** zero-shot model.
- for what it's worth it's decent enough for me to finally be happy with it.
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
* [x] ~~explore alternative setups, like a NAR-only model or Descript-Audio-Codec~~
- the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
- Descript-Audio-Codec 44KHz has NAR issues, but this *might* be user error.
* [x] ~~explore better sampling techniques~~
- the AR doesn't *need* exotic sampling techniques, as they're bandaids for a bad AR.
- the NAR benefits from greedy sampling, and anything else just harms output quality.
* [x] clean up the README, and document, document, document.
* [x] extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)).
- reference model is trained against English, Japanese, French, and German.
- [ ] improve multi-lingual support
- reference model is trained against English, Japanese, French, German, Korean, and Chinese (Mandarin?).
- [x] improve multi-lingual support
- [ ] improve cross-lingual support
* [ ] extend to additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- `stt` (Speech-to-Text) seems to be working fine for the most part, but is very much a second-class feature.
- other tasks seem to require a ton of VRAM......
- SpeechX tasks might need to be reworked to fit well within the `NAR-len` context to make full use of masking (for example, for speech editing)
- ***possibly*** voice conversion through the `NAR-len` with clever demasking tricks (for example, the tokens that are masked are from the source voice)
* [ ] ~~extend using [VALL-E 2](https://arxiv.org/pdf/2406.05370)'s features (grouped code modeling + repetition aware sampling)~~
- desu these don't seem to be worthwhile improvements, as inferencing is already rather fast, and RAS is just a fancy sampler.
* [ ] audio streaming
- this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
- something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming.
@ -80,16 +72,10 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
* [ ] replace the phonemizer with something that doesn't depend on espeak
* [ ] train the model to handle text => phoneme (without a hit to the rest of the model)
* [ ] ...and phonemes => text
* [ ] allow raw text as input instead
- espeak is nice, but I can only really put my whole trust with phonemizing English.
- a small model trained to handle converting text to phonemes might work, but has its own problems (another model to carry around, it's only as accurate as the dataset it was trained against, it requires training for each language, etc.).
* [ ] using a pure text vocab rather than IPA phonemes (as a transformer should be "smart" enough to map text tokens)
* [ ] smarter/clever inferencing, such as:
* [x] "rolling" context, where the last generated sentence is the prefix for the next sentence.
* for the AR, stop inferencing sequences in the batch that have already hit their stop token
* [ ] explore exotic features like:
* using a pure text vocab rather than IPA phonemes (as a transformer should be "smart" enough to map text tokens)
* mixing multiple speakers through summing input prompt embeddings
* I do not expect this to work, but you never know...
* [ ] for the AR, stop inferencing sequences in the batch that have already hit their stop token
* [x] objective metrics such as WER / SIM-O
* [x] WER simply requires transcribing audio then computing word error rates through the transcriptions
* [x] SIM-O requires passing the raw waveform through a speaker-similarity model
@ -104,19 +90,9 @@ However, while this solution boasts being lightweight, there are some caveats fo
* wrangling it is a bit of a chore, as some voices work fine under the `AR` but not the `NAR-len`, and vice-versa
* some voices outright refuse to work without LoRA training
* some sampler settings work on some voices, but others need some tweaking
* for short durations, it excels, but despite training on longer durations, stability is less guaranteed
* subjugating an existing LLM architecture is a bit of a pain, as I would *love* to make full use of LLaMA niceties
* `hf`-ifying it is possible, but it'd be a chore to set up the tokenizer properly
* it still seems like the phase of the moon matters with how it wants to cooperate
* in some eval tests it seems fine; other times, issues like word errors will crop up
* the `NAR-len` requires CFGs > 2-ish to cooperate (or a prefix)
* this isn't *so* much of an issue, but it can lead to user error, and CFG incurs an additional inference pass per sampling step.
* guidance distillation would be nice, but distillation in general harms finetuning (assuming this just as likely harms it)
* rolling context/prefix does solve this
* VALL-E Continuous (prefixing with the input prompt) could also fix this, but technically makes it one-shot and not zero-shot
* multi-lingual support is a bit of an afterthought
* supported non-English languages exhibit the confidence problem some speakers have, but exacerbated
* there's a regression in the `ar+nar-len-llama-8` model with a decrease in speaker similarity.
## Notices and Citations
@ -124,7 +100,7 @@ Unless otherwise credited/noted in this repo or within the designated Python fil
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
- This implementation was originally based on [enhuiz/vall-e](https://github.com/enhuiz/vall-e), but has been heavily, heavily modified over time. Without it I would not have had a good basis to muck around and learn.
- This implementation was originally based on [enhuiz/vall-e](https://github.com/enhuiz/vall-e), but has been heavily, heavily modified over time. Without it, I would not have had a good basis to muck around and learn.
```bibtex
@article{wang2023neural,
@ -143,3 +119,23 @@ Unless otherwise credited/noted in this repo or within the designated Python fil
year={2022}
}
```
```bibtex
@inproceedings{emilia,
author={He, Haorui and Shang, Zengqiang and Wang, Chaoren and Li, Xuyuan and Gu, Yicheng and Hua, Hua and Liu, Liwei and Yang, Chen and Li, Jiaqi and Shi, Peiyang and Wang, Yuancheng and Chen, Kai and Zhang, Pengyuan and Wu, Zhizheng},
title={Emilia: An Extensive, Multilingual, and Diverse Speech Dataset for Large-Scale Speech Generation},
booktitle={Proc.~of SLT},
year={2024}
}
```
```bibtex
@INPROCEEDINGS{librilight,
author={J. {Kahn} and M. {Rivière} and W. {Zheng} and E. {Kharitonov} and Q. {Xu} and P. E. {Mazaré} and J. {Karadayi} and V. {Liptchinsky} and R. {Collobert} and C. {Fuegen} and T. {Likhomanenko} and G. {Synnaeve} and A. {Joulin} and A. {Mohamed} and E. {Dupoux}},
booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Libri-Light: A Benchmark for ASR with Limited or No Supervision},
year={2020},
pages={7669-7673},
note = {\url{https://github.com/facebookresearch/libri-light}},
}
```

View File

@ -1,7 +0,0 @@
# `ext/*`
This folder handles external model implementations, where the code is not easily offered as a package.
Currently, this just includes code for a RetNet, offered as either a TorchScale-compatible or a HuggingFace-compatible implementation.
Comments and attributions are under its `__init__.py`.

docs/metrics.md Normal file
View File

@ -0,0 +1,13 @@
# `metrics.py`
This file provides helper functions for computing objective metrics, such as word-error rate (WER), character-error rate (CER), and speaker similarity (SIM-O).
## WER / CER
Word-error rate (WER) is simply computed by transcribing the requested input, and comparing its transcription against the target transcription.
Because of issues with normalization (and not having a robust normalization stack), both transcriptions are phonemized, and the resultant phonemes are used for the error rate calculation.
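As a rough sketch of the calculation described above, assuming some `phonemize(text) -> list[str]` helper (the repo's actual normalization and phonemization pipeline is more involved):

```python
def edit_distance(ref: list[str], hyp: list[str]) -> int:
    # plain Levenshtein distance over token (phoneme) sequences
    prev_row = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur_row = [i]
        for j, h in enumerate(hyp, 1):
            cur_row.append(min(
                prev_row[j] + 1,             # deletion
                cur_row[j - 1] + 1,          # insertion
                prev_row[j - 1] + (r != h),  # substitution
            ))
        prev_row = cur_row
    return prev_row[-1]

def error_rate(reference: str, hypothesis: str, phonemize) -> float:
    # phonemize both transcriptions first to sidestep text-normalization issues,
    # then compute the error rate over the resulting phoneme sequences
    ref, hyp = phonemize(reference), phonemize(hypothesis)
    return edit_distance(ref, hyp) / max(1, len(ref))
```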
## SIM-O

View File

@ -55,6 +55,7 @@ One problem exhibited from a NAR is producing artifacts ("crust") in the final waveform
* `token_dropout_rate`: This will randomly mask off tokens from the prior RVQ level with a mask token, to discourage the model from relying too strongly on the given input.
Sampling from the NAR absolutely necessitates either a low temperature or greedy sampling, as higher temperatures lead to the aforementioned artifacts in the final waveform.
* This is mostly mitigated with a proper non-causal mask, but crust still emerges at higher temperatures.
Traditional samplers do not seem to have much effect on the output, as the NAR's output is already decent enough.
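For illustration, the difference between greedy and temperature sampling over the NAR's per-position logits boils down to something like this (a generic sketch, not the repo's sampler code):

```python
import torch

def sample_nar(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    # logits: (seq_len, vocab) -- one distribution per parallel output position
    if temperature <= 0.0:
        # greedy sampling: what the NAR wants
        return logits.argmax(dim=-1)
    # temperature sampling: higher temperatures tend to produce the "crust" described above
    probs = (logits / temperature).softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```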
@ -72,16 +73,19 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
The NAR-len model keeps things simple by:
* training with a fixed masking ratio (80% of the tokens are masked and trained to predict the remaining tokens)
* [this paper](https://arxiv.org/abs/2406.05478v1) mentions a fixed ratio during training yields better results than randomly picking a masking ratio.
* randomly picking a duration ~~is actually very ungood and harms the model during training~~ actually doesn't matter much.
* randomly picking a duration ~~is actually very ungood and harms the model during training~~ ~~actually doesn't matter much~~ matters enough to warrant sticking with a fixed rate.
* theoretically, it should help later stages in demasking to better rely on the non-masked tokens, but who knows.
* in reality, it seems to harm the model being able to produce decent results in fewer steps.
* not including any specific timestep embedding information
* some solutions add in the (sinusoidal position'd) timestep embedding, either on top of the input embeddings, or as some normalization weight around the attention head (before and after).
* it does not seem to be necessary whatsoever, especially when training under a fixed masking ratio.
* in reality, the model shouldn't really need to reference this anyways.
* in theory, attention *could* deduce this from the amount of masked tokens vs unmasked tokens in the sequence.
* in reality, the model shouldn't really need to reference this anyways, as there's no reason for the model to make use of this information when it's trying to predict what *all* masked tokens should be.
* predicting the "duration" (the output audio token window) is kept within the model itself, by autoregressievly inferencing the duration for a given input prompt (text + audio).
* the model already "knows" the duration for a given prompt from AR RVQ level 0, by predicting when to output the stop token, so it makes sense to re-use the model for this.
* the output length is a simple tokenized sequence where each token is a base-10 digit (see the sketch after this list).
* it could be in any base, but it's simplest to just treat each token ID as a digit, then cast the joined string to an int.
* this technically doesn't need to rely on an AR sequence to predict it, either.
* some checkpoints of the model seem to adhere well to outputting silence at the end if the requested duration exceeds the actual duration.
* inferencing is a simple loop that takes the best k masked-off tokens per step and remasks the remaining (a sketch of this loop follows the list of crucial points below).
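As a concrete illustration of the digit encoding mentioned above (helper names are illustrative, not the repo's exact API):

```python
# assumes each duration token's ID directly maps to its base-10 digit (0-9)

def encode_duration(n_frames: int) -> list[int]:
    # 215 output frames -> [2, 1, 5]
    return [int(d) for d in str(n_frames)]

def decode_duration(digit_tokens: list[int]) -> int:
    # [2, 1, 5] -> 215
    return int("".join(str(d) for d in digit_tokens))

assert decode_duration(encode_duration(215)) == 215
```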
@ -96,9 +100,11 @@ In theory, demasking for the NAR's RVQ level 0 can also be applied to the remain
It is ***crucial*** to:
* avoid re-masking tokens that are already "good" enough (this can easily be done by "banning" them in the scoring process)
* without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem, but I imagine it "interleaves" many different sequences between each step.
* (although token remasking shows that this isn't a strict requirement)
* use unfiltered/unprocessed logit scores:
* not that crucial, but helps stability.
* not that crucial, but it helps stability by indicating which parts of the sequence were "confident enough" to keep.
* use a CFG strength of at least 2 (or a prefix)
* the output falls apart completely without this.
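Putting the loop and the points above together, a toy sketch of the demasking schedule might look like the following (written against an assumed `model(resps, cfg_strength=...)` callable rather than the repo's actual `AR_NAR` code):

```python
import math
import torch

def demask_sketch(model, seq_len: int, mask_token: int, steps: int = 25, cfg_strength: float = 3.0):
    # start fully masked; scores track how "confident" each filled position is
    resps = torch.full((seq_len,), mask_token, dtype=torch.long)
    scores = torch.zeros(seq_len)

    for step in range(steps):
        # cosine schedule: the fraction of positions that get (re)masked this step
        noise_p = math.cos((step / steps) * math.pi * 0.5)
        n_masked = max(1, int(noise_p * seq_len))

        # remask the worst-scoring positions; "good enough" tokens were banned
        # from remasking by having their score pinned high below
        masked_indices = scores.topk(n_masked, largest=False).indices
        resps[masked_indices] = mask_token
        is_masked = resps == mask_token

        # the model fills in every masked position in parallel
        # (assumes CFG is already merged into the returned logits)
        logits = model(resps, cfg_strength=cfg_strength)  # (seq_len, vocab)
        probs = logits.softmax(dim=-1)
        sampled = probs.argmax(dim=-1)          # greedy, per the NAR notes above
        confidence = probs.max(dim=-1).values   # unfiltered scores, per the notes above

        # only overwrite positions that were masked this step;
        # ban kept positions from future remasking by pinning their score high
        resps = torch.where(is_masked, sampled, resps)
        scores = torch.where(is_masked, confidence, torch.ones_like(scores))

    return resps
```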
It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model).
* additional training is still required to help confidence issues and to condition the model to not fall apart for longer durations.
@ -110,6 +116,9 @@ The "magic" of subjugating a transformer for audio use lies within the ensemble
While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.
With attention-based transformers, most embeddings can serve as tokens themselves and have the attention mechanism attend to them. Theoretically, there should be little to no functional difference between "tokenizing" an embedding and summing a modifying embedding (sketched below), but experimentation is needed to verify this assertion.
* EnCodec seems to function perfectly fine with summing and without, but other codecs such as Descript-Audio-Codec might absolutely require summing.
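To make the contrast above concrete, the two strategies for a "modifier" embedding (e.g. a language or task embedding) look roughly like this; tensor shapes are illustrative:

```python
import torch

d_model = 1024
seq = torch.randn(17, d_model)       # embedded input sequence
modifier = torch.randn(d_model)      # e.g. a language / task embedding

# option 1: sum the modifier onto every position (a "modifying" embedding)
summed = seq + modifier

# option 2: "tokenize" the modifier -- prepend it as its own token
# and let the attention mechanism attend to it
tokenized = torch.cat([modifier.unsqueeze(0), seq], dim=0)
```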
Other solutions such as TorToiSe make use of additional embeddings/classifiers for each portion of the sequence as well.
### Classifiers
@ -126,14 +135,21 @@ Technically, due to how the audio embeddings are implemented, it's possible to o
These embeddings *could* instead be added on top of the input prompt embedding instead of serving as additional tasks (similar to injecting position embeddings), but additional experimentation is required to see if the model both can work under this and/or benefits from this.
These embeddings can also be substituted out for a "text semantic" embedding, rather than tokenized phonemes, as the text conditioning input.
* Additionally, methods like [BLT](https://github.com/facebookresearch/blt) can replace this instead, as patching the audio portion wouldn't gain much benefit due to it already being quantized audio.
#### Language Embedding
This embedding provides the requested language for the model to be aware of.
This *mostly* isn't necessary, but VALL-E X's paper mentions needing a token for the language itself, and other solutions like XTTS2 provide a language token as well.
In practice, this seems to help govern the accent / general mannerisms associated with that language. For example, prompting French or German with the language set to `en` will give typical foreigner speech of trying to speak a language they don't know.
In reality, this seems to help govern the accent / general mannerisms associated with that language.
* For example:
* prompting French or German with the output language set to `en` will produce the typical speech of a foreigner trying to speak a language they don't know.
* prompting a Japanese speaker with the output language set to `ko` or `zh` will offer little change to the spoken language (at least no nuance I can hear as an EOP).
* Consequently, since this does tie to accents more, ***extreme*** attention is to be paid to the dialects being trained against, instead of naively grouping, say, all of Spanish to one language code.
* unfortunately, this does mean that audio annotated as English is dialect/accent-agnostic, per the dataset.
This embedding probably helps the model with being able to perform cross-lingual outputs, but I did not do any experiments on a model without this, as the reference `ar+nar-llama-8` was trained with this from the beginning with the small amount of Japanese in my dataset anyhow (and maybe the `ar+nar-retnet-8` experiment).
@ -143,7 +159,7 @@ This embedding *should* provide information on the tone for the model to output
Should, since I do not actually make use of this anywhere, and the model is not trained against any tones. I would need to annotate my dataset based on tones *and* pick which tones I do want.
This should most definitely help the model identify tone strongly even without needing to annotate for it, but it does an adequate already with maintaining tone from a given input prompt.
This should most definitely help the model identify tone strongly even without needing to annotate for it, but it does an adequate job already with maintaining tone from a given input prompt.
### Audio Embeddings
@ -154,8 +170,8 @@ As EnCodec encodes audio across eight codebooks (and DAC's 44Khz audio under nin
For the `prom` embedding, we can simply use each embedding for each layer. Each embedding level maps to its respective RVQ level.
Howver, the `resp` requires some extra care, as the model needs to both causally (AR) and parallel-ly (NAR) decode tokens.
* The first embedding level pertains to RVQ level 0 for the AR (`AR:0:0`).
However, the `resp` requires some extra care, as the model needs to both causally (AR) and parallel-ly (NAR) decode tokens.
* The first embedding level pertains to RVQ level 0 for the AR (`AR:0:0`) or NAR (`NAR:0:0`).
* This embedding predicts tokens within its own embedding.
* The remaining embedding levels map to RVQ level 0 + n for the NAR (`NAR:L-1:L`).
* In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
@ -163,7 +179,7 @@ Howver, the `resp` requires some extra care, as the model needs to both causally
* In other words, each embedding needs to be separated based on what tokens it predicts.
The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
* The `text` embedding's robustness not only for reusing between each RVQ level, but as STT task as well is a mystery.
* The `text` embedding's robustness not only for reuse between each RVQ level, but for the `stt` task as well is a mystery.
Finally, the model *may* then sum each embedding level back down to one sequence, as defined under `cfg.model.experimental.audio_embedding_sums`.
* The resultant sum is not normalized by the length.
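A minimal sketch of the per-RVQ-level embedding sum described above (module names and sizes are illustrative, not the repo's exact layout):

```python
import torch
import torch.nn as nn

class AudioEmbeddingSum(nn.Module):
    # one embedding table per RVQ level; their outputs are summed into one sequence
    def __init__(self, n_levels: int = 8, n_tokens: int = 1024, d_model: int = 1024):
        super().__init__()
        self.embs = nn.ModuleList(nn.Embedding(n_tokens, d_model) for _ in range(n_levels))

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: (seq_len, n_levels) quantized audio tokens
        # the resulting sum is (seq_len, d_model) and, as noted, not length-normalized
        return sum(emb(codes[:, level]) for level, emb in enumerate(self.embs))
```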
@ -176,6 +192,10 @@ Finally, the model *may* then sum each embedding level back down to one sequence
Additionally, it's *technically* possible to instead use the embeddings from the model used to encode the audio (for example, EnCodec's embeddings), but in theory this may exclude specific features the model has encoded within the embeddings.
Either embedding can be used to compute utterance similarity scores, as per `vall_e.emb.similarity`.
* I need to compare if this can be used as well for speaker similarities.
* The current implementation makes use of the `resp` embeddings for this, but the `proms` might be used instead (experimentation is needed for this).
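A sketch of how a similarity score could be computed from either set of embeddings (mean-pooling over time is an assumption here; `vall_e.emb.similarity` may differ):

```python
import torch
import torch.nn.functional as F

def utterance_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
    # emb_*: (seq_len, d_model) embeddings for each utterance
    # mean-pool over time, then take the cosine similarity of the pooled vectors
    return F.cosine_similarity(emb_a.mean(dim=0), emb_b.mean(dim=0), dim=0).item()
```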
#### RVQ Level Embedding
This embedding hints at which RVQ level of the audio codes is being targeted. This embedding is not required, but it seems some architectures (Mamba) require it.
@ -196,6 +216,7 @@ The primary zero-shot text-to-speech synthesis `tts` task takes in a requested t
The model primarily functions in a zero-shot setting, where it does not need a guiding prefix, but few-shotting is possible through manual intervention.
* I believe the original VALL-E paper refers to this more as `VALL-E Continuous`, while some other TTS solutions follow this method by transcribing the input audio prompt as well.
* Guidance prefixing is offered in the implementation, but right now is only exposed under "rolling context/prefix" through the web UI / CLI (where the previous segment is used as the prefix for the next).
Additional tasks are implemented in this project, but ***are yet to be trained for*** in the reference model (as some tasks require additional compute-cost).
@ -250,11 +271,13 @@ This task will follow a reverse sequence of `<audio><language><RVQ level><output
## Emergent Behavior
The model can be prompted in creative ways to yield some interesting behaviors:
* prompting without an input audio prompt will have the model generate a random voice at the "cost" of some unintelligible utterance at the beginning of the output response (despite doing no promptless training).
* prompting without an input audio prompt will have the model generate a random voice ~~at the "cost" of some unintelligible utterance at the beginning of the output response (despite doing no promptless training)~~.
* classifier-free-guidance-aware training does fix this, but this property emerges without it.
* the AR is much better with this property, as the `NAR-len` gets crusty sometimes.
* prompting with an input text prompt being the transcription of the input audio prompt will have the response follow very closely to the input prompt (despite not doing input=output training).
* this should allow for easy transcription editing without much fuss.
* the `NAR-len` greatly exhibits this property, although it sometimes does keep any noise in the background.
* extra care is required when doing this, as some checkpoints of the model will degrade completely the moment the prompt can't be directly referenced.
# `models/*`
@ -324,7 +347,7 @@ This script modifies modules of BitNet to play nicely with my existing code.
### `models/arch/llama.py`
This script modifes modules of LLaMA provided through `transformers`.
This script modifies modules of LLaMA provided through `transformers`.
A bulk of it pertains to modifying `LlamaAttention` and detecting available attention mechanisms, allowing for using different attention mechanisms:
* `torch.nn.functional.scaled_dot_product_attention`-based attention:
@ -389,16 +412,4 @@ This folder contains specific attention mechanisms.
Currently, only `fused.py` is provided, which implements fused attention through Triton.
Attributions are noted at the top of the respective file(s).
### `models/arch/mamba_vasqu`
This folder contains an implementation of Mamba2 as a HuggingFace-compatible model that does not require Triton.
Attributions are noted at the top of the respective file(s).
### `models/arch/retnet_syncdoth`
This folder contains scripts to modify modules within a RetNet model.
Attributions are noted at the top of the respective file(s).

View File

@ -32,13 +32,6 @@ This script contains code to handle sampling from a list of indices.
Each sampler can load and store a state dict.
## `utils/unsloth.py`
This script contains Unsloth, a VRAM-saving optimization that offloads the input tensors to CPU on a backwards pass.
This is mostly unnecessary, as inputs are rather small themselves, but is offered nonetheless if needed through `cfg.optimizations.unsloth = True`.
Attributions are noted at the top.
## `utils/utils.py`
@ -57,4 +50,26 @@ This script handles the necessary code for training, such as:
This script contains optimizations and additional code that require injecting or replacing modules.
Most configurations are offered through `cfg.optimization`.
## `utils/ext/`
This folder contains external code that can't be nicely referenced under a package.
Proper attribution is noted at the top of each file.
### `utils/ext/apollo.py`
This script contains [APOLLO](https://github.com/zhuhanqing/APOLLO), an optimizer that achieves ADAMW-like performance with very little memory cost.
In testing, this seems to work fine, and the memory gains (in comparison to Prodigyopt) under the normal-specced model allow you to double the batch size.
It's definitely usable under extremely low VRAM environments, and specifying `apollo-mini` will further shrink the memory requirements (but robustness is yet to be personally tested).
However, after a while, it seemed to make some steps either overflow the gradients or produce NaNs that persist even when swapping back to `prodigyopt` (but I do not know if it's the fault of `APOLLO` or just the model eventually hitting a point of instability).
### `utils/ext/unsloth.py`
This script contains Unsloth, a VRAM-saving optimization that offloads the input tensors to CPU on a backwards pass.
This is mostly unnecessary, as inputs are rather small themselves, but is offered nonetheless if needed through `cfg.optimizations.unsloth = True`.
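The underlying idea is the same as PyTorch's built-in saved-tensor offloading; a generic sketch of the technique (not the repo's actual hook-injection code):

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

# keep the tensors saved for the backwards pass on the CPU instead of in VRAM
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).square().mean()
loss.backward()
```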

View File

@ -258,6 +258,11 @@ class AR_NAR(Base):
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))
# to specify the initial mask used
mask_list = sampling_kwargs.pop("mask_list", None)
if mask_list is not None:
    len_list = [ x.shape[0] for x in mask_list ]
    len_list = [ clamp(l, min_length, max_length) for l in len_list ]
# force set CFG because too low / no CFG causes issues
@ -300,10 +305,17 @@ class AR_NAR(Base):
remask_p = 1.0 / (max_steps * 2) if remasking else 0
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# mask off inputs
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
is_masked = [ resps == self.stop_token for resps in resps_list ]
if mask_list is None:
    # mask off inputs
    resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
    # boolean mask
    is_masked = [ resps == self.stop_token for resps in resps_list ]
else:
    # mask off inputs
    resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, mask_list ) ]
    # boolean mask
    is_masked = [ resps == mask for resps, mask in zip( resps_list, mask_list ) ]
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]

View File

@ -102,14 +102,14 @@ def _non_blocking_input():
def _make_infinite_epochs(dl):
    if dl.dataset.batches() == 0:
        raise Exception("Empty dataset!")
    while True:
        if dl.dataset.index() == 0:
            _logger.info("New epoch starts.")
        total = dl.dataset.batches() - dl.dataset.index()
        if total <= 0:
            raise Exception("Empty dataset")
        with tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader(), total=total) as pbar:
            yield from pbar