added some more notes, tweaks (RIP DAC, it's over)

This commit is contained in:
mrq 2025-04-17 20:24:40 -05:00
parent 9e27d2e02e
commit 98d1d8cb1e
7 changed files with 81 additions and 50 deletions

View File

@ -57,16 +57,18 @@ For audio backends:
- encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality, but has issues with model convergence.
- models at 24KHz + 8kbps will NOT converge in any manner.
- models at 44KHz + 8kbps will work for lower codebook levels, but higher codebook levels will ***always*** have issues
* this seems to be inherent to the codec itself and not the model, as separate implementations have this problem
* [`nvidia/audio-codec-44khz`](https://huggingface.co/nvidia/audio-codec-44khz): boasts even better compression and quality
- this codec employs FSQ instead of RVQ.
* this doesn't seem to have any problems inherent to the codec itself, but instead inherent to FSQ codecs in general
#### Descript-Audio-Codec
Descript-Audio-Codec was thoroughly tested, as it promises much, much cleaner output audio by encoding/decoding at 44.1KHz, rather than EnCodec's 24KHz.
However, due to the nature of the codec, simply throwing it at an attention-based transformer proves to be painful, as the model *heavily* suffers from noisy output in the higher half of the codebook levels.
* the solution may be to simply encode / decode with *all* codebook levels in one pass.
Ironically, testing with erroneously encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) proved to yield "cleaner", yet still bad, utterances.
@ -107,8 +109,8 @@ This script handles taking either raw input audio, or processed encoded audio, a
* For raw input audio, the MFCCs (Mel-frequency cepstral coefficients) are extracted as features from the waveform, and the cosine similarities are compared against every other utterance for a given speaker (a sketch follows this list).
* This works *fine*, as this is adequately accurate and does not require a model to already exist.
* For the encoded audio, the audio codes are passed through the model's embedding, summed to one "token", and the cosine similarities are compared to score the top-K similar speakers.
* By default, the output response embedding is used, and each codebook level is summed together to leave one sequence.
* In theory this should be better as the model may have its own features per codebook + level, but still requires a model to already be trained.
* The original encoding model's embeddings, or the last hidden states passed through the model, can also be used instead, but this seems like overkill.
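A rough sketch of the raw-audio path, assuming `torchaudio`; the helper name and defaults here are illustrative, not the actual `vall_e.emb.similar` code:

```python
# Hypothetical sketch: average MFCCs over time to get one descriptor per
# utterance, then compare utterances by cosine similarity.
import torch
import torchaudio

def mfcc_similarity( path_a, path_b, sample_rate=24_000, n_mfcc=13 ):
	transform = torchaudio.transforms.MFCC( sample_rate=sample_rate, n_mfcc=n_mfcc )
	features = []
	for path in ( path_a, path_b ):
		wav, sr = torchaudio.load( path )
		wav = torchaudio.functional.resample( wav.mean( dim=0, keepdim=True ), sr, sample_rate )
		# (1, n_mfcc, frames) => (n_mfcc,) by averaging over time
		features.append( transform( wav ).mean( dim=-1 ).flatten() )
	return torch.nn.functional.cosine_similarity( features[0], features[1], dim=0 ).item()
```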
When processing a dataset, this requires already having accompanying metadata generated through `vall_e.data --action=metadata --yaml=./your/training/config.yaml`.

View File

@ -17,7 +17,7 @@ While the original paper called for a separate AR model and a NAR model, by trea
## The AR (Autoregressive) Model
The AR is responsible for generating the first codebook level of the audio codes for a given output. References to "outputs from the AR" refer to this level, as it contributes to the final waveform the most (a decoding sketch follows the list below).
* Some models may refer to this level as the "coarse" level.
* The benefit of autoregressively decoding for this code is that it offers better output while also "encoding" the duration within the sequence itself, as the stop token will depend on the length of the sequence.
* The downside is that it does take most of the compute time to iterate through the sequence one step at a time.
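To make the stop-token behavior concrete, a toy greedy decoding loop; the `model` call signature here is an assumption, not the repo's API:

```python
# Toy sketch of AR decoding for codebook level 0: the loop ends when the model
# emits the stop token, which is how the duration gets "encoded" in the sequence.
import torch

def ar_decode( model, prefix, stop_token, max_steps=1500 ):
	sequence = prefix                          # (seq,) token ids: text + prompt + partial response
	codes = []
	for _ in range( max_steps ):
		logits = model( sequence )             # assumed to return (seq, n_tokens) logits
		token = logits[-1].argmax( dim=-1 )    # greedy sampling for brevity
		if token.item() == stop_token:
			break
		codes.append( token )
		sequence = torch.cat( [ sequence, token.view(1) ] )
	return torch.stack( codes ) if codes else torch.empty( 0, dtype=torch.long )
```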
@ -40,20 +40,20 @@ Compared to non-autoregressive decoding, I personally feel that autoregressive e
## The NAR (Non-autoregressive) Model
The NAR is responsible for generating the remaining codebook levels of the audio codes for a given output. References to the "outputs from the NAR" refer to the underlying "levels" for a given waveform, as each further level contributes to the final waveform less significantly than the previous.
* Some models may refer to these levels as the "fine" levels.
As decoding is done non-autoregressively, the model can process tokens "in place" and have them attended to one another in the past and future, thus speeding up output and allowing for "more accurate" outputs.
Non-autoregressive training is performed by having the input tokens from the previous codebook level predict the next level's token in place. The output logits are in the same position, and do not require further modifications as required for the AR.
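A minimal sketch of what "predict the next level's token in place" amounts to; the `model` signature is an assumption:

```python
# Hedged sketch of NAR training for codebook level `level`: the inputs are the
# codes from level - 1, and the logits at the *same* positions are scored
# against that level's codes, with no shifting like the AR requires.
import torch
import torch.nn.functional as F

def nar_loss( model, conditioning, codes, level ):
	# codes: (frames, levels) audio codes for the target utterance
	inputs  = codes[:, level - 1]
	targets = codes[:, level]
	logits  = model( conditioning, inputs, level )   # assumed to return (frames, n_tokens)
	return F.cross_entropy( logits, targets )
```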
One problem exhibited by a NAR is producing artifacts ("crust") in the final waveform. I believe this is a confidence problem where the wrong token is inferred.
* Unfortunately, one solution is to simply train a separate NAR, as this should help bolster the model's NAR capabilities without the AR influencing things, as I imagine being able to both causally and parallel-ly decode tokens harms things.
* This is backed by how the `cfg.model.experimental.rvq_levels_p` distribution affects the model's AR capabilities, as increasing the NAR's share in training causes the AR to perform *worse*.
* However, this may simply be wrong, but checkpoints that used such distributions felt lobotomized.
* Another solution that may help is to provide two token dropout methods (both sketched after this list):
* `token_dropout_error`: This will randomly nudge a small percentage of tokens from the prior codebook level to simulate wrong tokens being predicted.
* `token_dropout_rate`: This will randomly mask off tokens from the prior codebook level with a mask token, to try and have the model not-strongly-rely on the given input.
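A rough sketch of both dropout schemes; the function name, defaults, and mask-token convention are assumptions, not the exact config semantics:

```python
# Rough sketch: corrupt the prior codebook level's tokens before feeding them in.
import torch

def corrupt_prior_level( codes, n_tokens=1024, mask_token=1024, dropout_error=0.01, dropout_rate=0.05 ):
	codes = codes.clone()                      # (frames,) prior-level token ids
	# token_dropout_error: nudge a small percentage of tokens to a random (wrong) token
	error = torch.rand( codes.shape ) < dropout_error
	codes[error] = torch.randint( 0, n_tokens, ( int( error.sum() ), ) )
	# token_dropout_rate: mask off a percentage of tokens with the mask token
	drop = torch.rand( codes.shape ) < dropout_rate
	codes[drop] = mask_token
	return codes
```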
Sampling from the NAR absolutely necessitates a low temperature or greedy sampling, as higher temperatures lead to the aforementioned artifacts in the final waveform.
* This is mostly mitigated with a proper non-causal mask, but crust still emerges at higher temperatures.
@ -83,7 +83,7 @@ The NAR-len model keeps things simple by:
* in theory, attention *could* deduce this from the amount of masked tokens vs unmasked tokens in the sequence.
* in reality, the model shouldn't really need to reference this anyways, as there's no reason for the model to make use of this information when it's trying to predict what *all* masked tokens should be.
* predicting the "duration" (the output audio token window) is kept within the model itself, by autoregressively inferencing the duration for a given input prompt (text + audio).
* the model can already "know" the duration for a given prompt from an AR codebook level 0, by predicting when to output the stop token, so it makes sense to re-use the model for this.
* the output length is a simple tokenized sequence where each token is a base-10 digit (a toy sketch follows this list).
* it could be in any base, but it's simple to just treat each token ID as a digit, then cast the string to an int.
* this could literally also not be relying on an AR sequence to predict.
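A toy sketch of the base-10 duration encoding, purely for illustration:

```python
# Each duration token ID is literally a decimal digit of the output length.
def encode_duration( n_frames: int ) -> list[int]:
	return [ int( digit ) for digit in str( n_frames ) ]

def decode_duration( tokens: list[int] ) -> int:
	return int( "".join( str( token ) for token in tokens ) )

assert decode_duration( encode_duration( 742 ) ) == 742
```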
@ -93,7 +93,7 @@ The NAR-len model keeps things simple by:
Because the model already leverages the magic of attention to derive phoneme-alignment, such annotations are still not required (but they probably help with a naive sampler).
In theory, demasking for the NAR's codebook level 0 can also be applied to the remaining codebook levels to further improve the output from the remaining levels.
* this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
* this is technically already offered with `cfg.model.experimental.token_dropout_rate` which mirrors masking, but experimentation has not been done to a large degree.
* there is a bit of a problem with properly implementing this, as the tokens aren't predicting themselves.
@ -123,7 +123,7 @@ Other solutions such as TorToiSe makes use of additional embeddings/classifiers
Classifiers are the final output head / projection layer that processes the last hidden states of a model into a probability distribution for each token.
Out of paranoia, each head is split for each macro-task (codebook level, `stt`, and `len`), even though the core half of the model's training was with a single output head (a sketch follows below).
* It also helps with not needing to do some tricks by setting unwanted tokens to `-inf`.
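A hedged sketch of the split heads; the dimensions, head names, and vocabulary sizes are assumptions for illustration:

```python
# One output head per macro-task, so unwanted tokens never need to be set to -inf.
import torch.nn as nn

class SplitClassifiers( nn.Module ):
	def __init__( self, d_model=1024, n_audio_tokens=1024, n_levels=8, n_text_tokens=256, n_len_tokens=11 ):
		super().__init__()
		heads = { f"level_{level}": nn.Linear( d_model, n_audio_tokens + 1 ) for level in range( n_levels ) }
		heads["stt"] = nn.Linear( d_model, n_text_tokens )
		heads["len"] = nn.Linear( d_model, n_len_tokens )
		self.heads = nn.ModuleDict( heads )

	def forward( self, hidden, task: str ):
		# hidden: (..., d_model) last hidden states => logits for the requested task
		return self.heads[task]( hidden )
```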
### Text Embeddings
@ -171,23 +171,23 @@ However, due to the nature of the encoded audio, embedding the audio tokens requ
As EnCodec encodes audio across eight codebooks (and DAC's 44KHz audio under nine codebooks), our audio is encoded under a 2D space, rather than a simple 1D space like text is. Because of this, we require embeddings for *every* codebook level, effectively giving eight embedding heads for audio.
* Technically, this can be stored within a unified embedding head, but each layer is offset by 1024 (the number of tokens), as sketched below.
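A sketch of that unified-embedding trick; the sizes are EnCodec-flavored assumptions:

```python
# One big embedding table where codebook level `l` occupies rows [ l * 1024, (l + 1) * 1024 ).
import torch
import torch.nn as nn

n_tokens, n_levels, d_model = 1024, 8, 1024
unified = nn.Embedding( n_tokens * n_levels, d_model )

def embed_codes( codes ):
	# codes: (frames, levels) integer audio codes
	offsets = torch.arange( codes.shape[-1] ) * n_tokens
	return unified( codes + offsets )          # (frames, levels, d_model)
```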
For the `prom` embedding, we can simply use each embedding for each layer. Each embedding level maps to its respective codebook level.
However, the `resp` requires some extra care, as the model needs to both causally (AR) and parallel-ly (NAR) decode tokens.
* The first embedding level pertains to codebook level 0 for the AR (`AR:0:0`) or NAR (`NAR:0:0`).
* This embedding predicts tokens within its own embedding.
* The remaining embedding levels map to codebook level 0 + n for the NAR (`NAR:L-1:L`).
* In other words, embedding level 1 => codebook level 0, embedding level 2 => codebook level 1, etc...
* I believe this is required because the model encodes which task to perform (rather than the attention heads), and which tokens to predict (rather than the classifiers)
* In other words, each embedding needs to be separated based on what tokens they do predict.
The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
* Why the `text` embedding is robust enough to be reused not only between each codebook level, but for the `stt` task as well, is a mystery.
Finally, the model *may* then sum each embedding level back down to one sequence, as defined under `cfg.model.experimental.audio_embedding_sums` (a sketch follows this list).
* The resultant sum is not normalized by the length.
* It's not a requirement, as the model can still function only "seeing" the required codebook level.
* However, it *may* help to have the model being able to "see" prior levels, as one codebook level might depend on the prior level.
* This is mostly dependent on the underlying audio model being used, which would depend on how each residual is defined.
* A model not trained with summing embeddings can enable it without much impact, but a model trained on summing embeddings cannot go in the other way without further training.
* It *could* be beneficial to train a model under mixed modes, but requires experimentation.
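A sketch of that choice, continuing the per-level embedding shape from the earlier sketch:

```python
# Either sum the available levels into one sequence, or keep only the target level.
import torch

def collapse_levels( per_level, target_level, sums=True ):
	# per_level: (frames, levels, d_model) embedded audio codes
	if sums:
		return per_level[:, : target_level + 1].sum( dim=1 )   # note: not normalized by length
	return per_level[:, target_level]
```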
@ -199,17 +199,17 @@ Either embeddings can be used to compute utterance similarity scores, as per `va
* I need to compare if this can be used as well for speaker similarities.
* The current implementation makes use of the `resp` embeddings for this, but the `proms` might be used instead (experimentation is needed for this).
#### Codebook Level Embedding
This embedding hints at which codebook level of the audio codes is being targeted. This embedding is not required, but it seems some architectures (Mamba) require it.
This *may* replace needing separate embeddings for each codebook level, but experimentation is required to test this claim.
### Tasks
The base model handles processing inputs into token sequences, per the requested task assigned to each input in a batch.
Most sequences follow a `<text><codebook level><language><prompt><output>` layout, but some tasks will receive the prompt as a list of tensors, instead (a loose assembly sketch follows).
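A loose sketch of assembling one such sequence; the module names, sizes, and shapes are assumptions:

```python
# Concatenate the embedded parts in the <text><codebook level><language><prompt><output> order.
import torch
import torch.nn as nn

d_model   = 1024
level_emb = nn.Embedding( 8, d_model )   # codebook-level hint, one vector per level
lang_emb  = nn.Embedding( 4, d_model )

def build_sequence( text_emb, level, lang, prom_emb, resp_emb ):
	# text/prom/resp embeddings: (length, d_model) each
	parts = [
		text_emb,
		level_emb( torch.tensor( [ level ] ) ),
		lang_emb( torch.tensor( [ lang ] ) ),
		prom_emb,
		resp_emb,
	]
	return torch.cat( parts, dim=0 )         # (total_length, d_model)
```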
The nitty gritty of how each task is implemented is documented under [./docs/data.md](/docs/data.md).
@ -275,7 +275,7 @@ However, due to the model being trained on phonemes, the resultant output is the
The primary benefit of this task is to provide a fast way to directly transcribe audio into the phonemes used to annotate the dataset itself, but at the moment the reference model isn't accurate enough to rely on this.
* The other problem is that it's very hard to validate this, as the output isn't in English, and requires processing through the model again to verify the transcription.
This task will follow a reverse sequence of `<audio><language><codebook level><output>`.
#### Phonemize / Un-Phonemize
@ -331,10 +331,10 @@ Due to major enough differences, this code is segregated from the original `mode
## `models/ar_nar.py`
This script implements VALL-E as a unified autoregressive and non-autoregressive model, where codebook level 0 is inferenced autoregressively and the remaining levels are inferenced non-autoregressively, if requested.
* Since one model can be trained AR-ly and NAR-ly, codebook level 0 can also be trained non-autoregressively with diffusion-like masking (sketched below).
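A hedged sketch of that masking; the schedule and mask-token convention are assumptions:

```python
# Mask a random fraction of the level-0 codes; the model predicts the masked positions in place.
import math
import torch

def mask_level0( codes, mask_token, p=None ):
	# codes: (frames,) level-0 audio codes
	if p is None:
		p = math.cos( torch.rand( 1 ).item() * math.pi / 2 )   # cosine-ish masking schedule
	keep = torch.rand( codes.shape[0] ) >= p
	masked = torch.where( keep, codes, torch.full_like( codes, mask_token ) )
	return masked, ~keep    # model inputs, and which positions count towards the loss
```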
For training, this model handles preparing the batch provided through the dataloader according to a randomly sampled targeted codebook level.
For inferencing, this will dynamically inference depending on the arguments provided.

View File

@ -32,6 +32,15 @@ The `AudioDecoder` projects the last hidden state through another feed-forward n
* It might not even be necessary to use an MLP, as the model was quick to fix itself after deleting-then-shrinking the feed-forward expansion factor to try and squeeze out throughput.
* because of this ablation, it's *probably* okay to just do per-codebook layer norm + an output head, but that experimentation is for another day.
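A speculative sketch of that simplification; this is the *proposed* ablation, not the current `AudioDecoder`:

```python
# Per-codebook layer norm + output head, in place of the shared FFN decoder.
import torch
import torch.nn as nn

class PerCodebookHead( nn.Module ):
	def __init__( self, d_model=1024, n_tokens=1024, n_codebooks=8 ):
		super().__init__()
		self.norms = nn.ModuleList( nn.LayerNorm( d_model ) for _ in range( n_codebooks ) )
		self.heads = nn.ModuleList( nn.Linear( d_model, n_tokens ) for _ in range( n_codebooks ) )

	def forward( self, hidden ):
		# hidden: (batch, frames, d_model) => logits: (batch, frames, n_codebooks, n_tokens)
		return torch.stack( [ head( norm( hidden ) ) for norm, head in zip( self.norms, self.heads ) ], dim=-2 )
```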
### Ablations
For RVQ codecs, such as EnCodec and DAC, the `AudioEncoder.level_weights` can be ignored entirely without any problem.
For any codec, the `AudioEncoder.norm` can be omitted, as it doesn't make much sense to perform layer normalization pre-FFN when the input is just the embedding + codebook positioning embeddings.
* it *might* instead work when applying it to the input into the FFN rather than the input entirely, or applying it post-FFN on the residual connection.
Both `AudioEncoder.ffn` and `AudioDecoder.ffn` can have their expansion sizes adjusted and re-trained without much of an impact (for example, to downsize).
### `ResidualAudioEncoder/Decoder`
The implementation also includes an encoder/decoder targeted for residual codecs, but real-world testing shows that it does not perform anywhere near as well as the FSQ-targeted encoder/decoder setup.
@ -101,6 +110,7 @@ Both flavors were trained on the previously used dataset, but English-only utter
* Additional languages and the remaining 8 seconds to 12 seconds were re-introduced into the dataset. Non-English language performance needs to be evaluated, but it seems *fine*.
Additional tasks beyond text-to-speech (such as `ns`, `sr`, `stt`) were not trained for either model, as they're very low priority, and the logic to train for them might have been gutted from the implementation.
* `ns` and `sr` are being experimented with, but training is a ***huge*** pain as CPU-inferencing through the NAC is required per the dataloader
### Experimental Settings
@ -160,10 +170,8 @@ Additionally, this implementation paves the way a ton of neat features, such as:
However, output leaves a lot to be desired:
* despite what literature suggests, an FSQ codec is heavily disfavored by the current approach
* each codebook's importance is effectively dependent on the speaker itself, so even having priority be a "learned" parameter is tough
* if DAC succeeds where `nvidia/audio-codec-44khz` failed, then this would prove it's a codec problem
* if DAC does not, then it's simply a 44KHz problem, where the codebooks are too saturated with information for the model to properly utilize
* although I doubt this is true, as the model still performs fine for a subset of speakers trained against
* if Encodec fails too, then this implementation has an inherent flaw
* RVQ codecs don't have this problem, as each level will always have the same type of importance (so much so that `AudioEncoder.level_weights` can be ignored for RVQ-codec-based models)
* this architecture does not remove the inherent problem DAC-based models have, where the higher codebooks contribute too much noise
* both the small and the large model seemed to have hit a "capacity" limit
* the "confidence" problem of the prior implementation seems to have emerged even for typical speakers
* some other quirks and emergent behaviors inherent to the model I'm not aware of / can't recall

View File

@ -15,7 +15,7 @@ Training is (obviously) *very* dependent on:
* annotating each utterance with its top-k similar utterances through `vall_e.emb.similar` helps with prompt adherence in any stage of training
* how patient you are
* the original (`base.py`) implementation serendipitously has a curriculum that allows for speech to realize relatively fast with EnCodec (from what I remember)
* this is from how selecting which codebook level to train naturally "scales the loss" for higher, less important levels, and the model doesn't train each level in parallel at all
* the new (`base_v2.py`) implementation requires lots of patience, as it seems to require 8M samples for speech to properly realize
* this is from how several "stabilizers" are required to train it, as every sequence is inherently trained in parallel, but the loss calculation is not.
* the audio codec, to an extent
@ -29,7 +29,7 @@ Training is (not-so-obviously) not dependent on:
* per E2/F5's paper, a size of 1024 dim, 4x FFN, 16 heads, 24 layers might be preferable?
* it *probably* is necessary to have a larger model to better adhere to the reference clip, but experimentation is not at the point yet to verify this.
* the audio codec, to an extent
* for the old (`base.py`) implementation, EnCodec only seems to work well for it, as DAC might require some hacks or patience for the higher codebook levels to train, while `nvidia/audio-codec-44khz` requires an exotic training curriculum, assumedly.
* for the new (`base_v2.py`) implementation, given how EnCodec and `nvidia/audio-codec-44khz` both seem to behave the same, I assume this implementation is almost agnostic to any codec (as long as RVQ/FSQ-ness is signaled properly).
* each codec will have different curriculum requirements, and the ease with which coherent speech emerges from each level will vary

View File

@ -12,7 +12,7 @@ import torch
import itertools
from .config import cfg
from .emb.qnt import post_process, trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence
from .emb.g2p import encode as encode_phns
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size, is_global_leader
@ -786,13 +786,7 @@ def _load_artifact(path, return_metadata=False, return_artifact=False, validate=
raise Exception(f"Artifact contains zero'd tensor: {path}")
codes = torch.from_numpy(codes.astype(int)).to(torch.int16)
codes = post_process( codes )
if return_artifact:
return codes, artifact

View File

@ -444,6 +444,15 @@ def encode_from_file(path, device="cuda"):
Helper Functions
"""
def post_process( codes ):
# artifact was saved as a batch
if codes.dim() == 3:
codes = codes[0]
# (codebook, frame) => (frame, codebook)
if codes.shape[0] < codes.shape[1]:
codes = codes.t()
return codes
# DAC "silence": [ 568, 804, 10, 674, 364, 981, 568, 378, 731]
# trims from the start, up to `target`
@ -471,7 +480,8 @@ def trim( qnt, target, reencode=False, device="cuda" ):
end = end / cfg.dataset.frames_per_second * cfg.sample_rate
wav = decode(qnt, device=device)[0]
res = encode(wav[start:end], cfg.sample_rate, device=device)
return post_process( res )
# trims a random piece of audio, up to `target`
# to-do: try and align to EnCodec window
@ -512,7 +522,7 @@ def interleave_audio( *args, audio=None ):
if i + 1 != len(qnts):
res.append( audio )
return post_process( res )
# concats two audios together
def concat_audio( *args, reencode=False, device="cuda" ):
@ -524,11 +534,16 @@ def concat_audio( *args, reencode=False, device="cuda" ):
decoded = [ decode(qnt, device=device)[0] for qnt in qnts ]
combined = torch.concat( decoded )
res = encode(combined, cfg.sample_rate, device=device)
return post_process( res )
# merges two quantized audios together
# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
def merge_audio( *args, device="cuda", scale=[] ):
# since this is more than likely being used in a dataloader worker, force disable CUDA since nvidia/audio-codec-44khz has problems where it'll try and use it despite the model not requesting it
if device == "cpu":
torch.cuda.is_available = lambda : False
qnts = [ *args ]
qnts = [ qnt for qnt in qnts if qnt is not None ]
decoded = [ decode(qnt, device=device)[0] for qnt in qnts ]
@ -548,7 +563,8 @@ def merge_audio( *args, device="cuda", scale=[] ):
decoded[i] = decoded[i] * scale[i]
combined = sum(decoded) / len(decoded)
res = encode(combined, cfg.sample_rate, device=device)
return post_process( res )
# Get framerate for a given audio backend
def get_framerate( backend=None, sample_rate=None ):

View File

@ -947,7 +947,11 @@ class Base_V2(nn.Module):
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
parts = [ prompt_input_to_token( i, quant_level ) for i in proms if i is not None ]
for i, p in enumerate( parts ):
if p.dim() == 1:
parts[i] = p.repeat(p.shape[0], self.n_resp_levels)
token = torch.cat( parts )
if logits[batch_index].dim() < 3 and token.dim() >= 2:
token = token[..., 0]
@ -1197,18 +1201,25 @@ class Base_V2(nn.Module):
windows = [text_window, audio_window, audio_window]
for name, input in batch_input:
if isinstance(input, list):
shape = sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] )
elif not isinstance(input, torch.Tensor):
continue
else:
shape = input.shape[0]
if name in ["phn", "text"]:
lens[0] = shape + 1
elif name == "lang":
lens[0] += 2
elif name == "prom":
lens[1] = shape + 1
elif name == "tone":
lens[1] += 2
elif name == "len":
lens[2] = 2
elif name == "resp":
lens[2] = shape
aux_lens.append( lens )
aux_windows.append( windows )