more notes / re-enabled top-k/p samplers for new implementation
This commit is contained in:
parent
f8e1d110dc
commit
5fe01ffc6c
|
@ -4,7 +4,7 @@ This section aims to document the `_v2` class of models. Documentation here migh
|
|||
|
||||
Unlike the original, this implementation strives to operate on *all* codebooks at once with a full 44KHz bandwidth, rather than requiring the model to operate on one codebook level at a time at 24KHz audio.
|
||||
|
||||
Sample weights can be found [here](https://huggingface.co/ecker/vall-e/).
|
||||
Sample weights can be found [here](https://huggingface.co/ecker/vall-e/) under the `nemo-` prefix.
|
||||
|
||||
## Audio Codecs
|
||||
|
||||
|
@ -35,6 +35,7 @@ The `AudioDecoder` projects the last hidden state through another feed-forward n
|
|||
### Ablations
|
||||
|
||||
For RVQ codecs, such as EnCodec and DAC, the `AudioEncoder.level_weights` can be ignored entirely without any problem.
|
||||
* FSQ codecs might require this, as one test suggests removing this caused a nasty problem.
|
||||
|
||||
For any codec, the `AudioEncoder.norm` can be omitted, as it doesn't make much sense to perform layer normalization pre-FFN when the input is just the embedding + codebook positioning embeddings.
|
||||
* it *might* instead work when applying it to the input into the FFN rather than the input entirely, or applying it post-FFN on the residual connection.
|
||||
|
@ -88,6 +89,8 @@ More experimentation is needed for this modality, but seeing as the pure NAR app
|
|||
However, this modality was not trained for either models, as there seems to be some weird quirk when inferencing that's caught under CUDA, but not ROCm. This doesn't seem to "go away" with more training, unfortunately.
|
||||
* Additionally, I'm under the impression that separate `resps_embs` are required for causal/non-causal sequences, as the previous implementation inherently has this split.
|
||||
|
||||
However however, this modality doesn't seem to be necessary as the benefits of a pure NAR model outweighs the benefits of a pure AR one.
|
||||
|
||||
## Training Regimen
|
||||
|
||||
The `nemo-smaller-44khz-llama-8` model is a 512-dim, 12 layered, 8 headed attention-based transformer with rotary position embedding. Training was performed on four V100s with AMP+`float16` with a batch size of 8 samples per GPU, and an AdamW optimizer with adequate parameters (`1.0e-4` learning rate, betas of `[0.8, 0.95]`, weight_decay of `0.01`, linear warmup to 5K steps before holding) for 400K steps before introducing training for duration prediction in parallel. The dataloader sorts the dataset by duration, starting from 2 seconds and ending with 8 seconds-ed utterances. Training consists of computing the loss for each codebook level non-parallely (where a level is randomly assigned to a sample per a "normal" distribution) with each loss being weighed "normal"ly, for 70% of the epoch when speech starts to emerge. Then, the model was trained to compute the loss paralelly (where all levels have the loss computed) without weighing the loss per-level. Audio quality was lacking for most speakers, as the model failed to handle all codebook levels adequately. Additional training slowly helps, but by-the-numbers metrics don't show much improvement.
|
||||
|
@ -95,7 +98,6 @@ The `nemo-smaller-44khz-llama-8` model is a 512-dim, 12 layered, 8 headed attent
|
|||
* it's reasonable to assume that a lot of the nitty gritty like LR warmup and slowly introducing features are entirely unnecessary
|
||||
* the model *may* benefit from setting the dataloader to a speaker-based one, so it can balance different speakers.
|
||||
* due to some form of regression, training under bfloat16 (even with AMP) will cause the gradient norm to slowly grow along with the loss.
|
||||
* I'm honestly not too sure, since an experimental `dac-smaller-44khz-llama-9` model was trained with similar settings and it was rather stable.
|
||||
|
||||
The `nemo-larger-44khz-llama-8` model is similar to its immediate predecessor, with 1024-dim, 24 layers, and 16 heads. Training is similar where the only difference is with a learning rate of `3.0e-4`. Speech emerged quicker than its predecessor at `?`% of the epoch, but quality remains about the same.
|
||||
* increasing the de-facto batch size and lowering the learning rate seems to be necessary to edge out improvements in speaker similarity.
|
||||
|
@ -103,6 +105,7 @@ The `nemo-larger-44khz-llama-8` model is similar to its immediate predecessor, w
|
|||
|
||||
Training of both models experienced degradation in quality periodically, where the loss will rise, spike, then climb back down. It's reasonable to assume this came from duration sorting being the cause, as the model might somehow "overfit" based on duration, as this problem disappeared when re-initializing the dataloader to instead batch samples by durations, then shuffle the batches. However, training throughput significantly dropped for the larger model.
|
||||
* Training should *probably* only have the dataloader duration-ordered until speech does emerge, then train an epoch with shuffled durations.
|
||||
* Speaker similarity does improve when balancing for it without risking overfitment when it doesn't cull any speakers, but there's still inherent problems that suggest a capacity limit being reached.
|
||||
|
||||
The differences between the two models start to emerge on how the model can generalize. The smaller model seems to have trouble handling a variety of speakers and no inherent way of inferencing duration, while the larger model reaches its "capacity" much much later in training.
|
||||
|
||||
|
@ -110,7 +113,8 @@ Both flavors were trained on the previously used dataset, but English-only utter
|
|||
* Additional languages and the remaining 8 seconds to 12 seconds were re-introduced into the dataset. Non-English language performance needs to be evaluated, but it seems *fine*.
|
||||
|
||||
Additional tasks beyond text-to-speech (such as `ns`, `sr`, `stt`) were not trained for either models, as they're very low priority, and the implementation might have had logic to train for it gutted.
|
||||
* `ns` and `sr` are being experimented with, but training is a ***huge*** pain as CPU-inferencing through the NAC is required per the dataloader
|
||||
* `ns` and `sr` are being experimented with, but training is a ***huge*** pain as CPU-inferencing through the NAC is required per the dataloader.
|
||||
* this could be worked around by baking a noised dataset rather than doing it at runtime (because EnCodec can easily do it, but not DAC/NeMO)
|
||||
|
||||
### Experimental Settings
|
||||
|
||||
|
@ -154,18 +158,24 @@ These settings should be avoided:
|
|||
|
||||
## Samplers
|
||||
|
||||
To-do: Remember what I was going to jot down here
|
||||
Sampling code is effectively the same, with the twist of instead outputting the logits for all codebooks at `dim=0`. Returned scores are before applying sampler settings, as the unfiltered scores are preferred when doing demasking.
|
||||
|
||||
Sampling code is effectively the same, with the twist of instead outputting the logits for all codebooks at `dim=0`.
|
||||
The NAR-demasking step will account for this automatically, and has dials and knobs to adjust. Per the web UI:
|
||||
* `Masked Only`: will update scores for previously masked tokens, and unmasked tokens will not get remasked.
|
||||
* This seems *fine* when backported to the old implementation, but the new one does not like doing this.
|
||||
* `Flattened`: will average out the scores for all codebook levels at a given timestep, effectively masking across all codebooks for a given timestep, rather than have it independently masked
|
||||
* This doesn't seem to have any changes, however. Greedy sampling with this on and off seems to always produce the same output, but there might be a bug.
|
||||
* `Remask`: will remask a random amount of previously unmasked tokens to keep them "updated" to the rest of the waveform by some ratio (that I forget but I think it's half the masking ratio)
|
||||
* This offers some differences, but further evaluation is required to see if this is more of a good than a bad.
|
||||
|
||||
The NAR-demasking step will account for this automatically, and has dials and knobs to adjust whether to mask off independent of other codebook levels, or for all codebook levels at a given timestep.
|
||||
Basic sampler modifiers like `top-k`/`top-p` need to be evaluated but are available. Exotic samplers from the prior implementation are not included as they're bandaids applied to a modality that does not benefit from it.
|
||||
|
||||
## Benefits and Caveats
|
||||
|
||||
To be evaluated thoroughly.
|
||||
* The model seems pretty quick, even for the large model.
|
||||
* The smaller model seems small enough for CPU-only inferencing
|
||||
* Despite its poor zero-shot performance, it could be perfectly fine for finetuning.
|
||||
* Despite its lacking all-around zero-shot performance, it could be perfectly fine for finetuning.
|
||||
|
||||
At a glance, compared to the prior model setup, this implementation allows for the model to better represent speech as it's able to see the entire signal and account for it in its latent space, rather than only operate on specific levels of it at a time.
|
||||
|
||||
|
@ -180,6 +190,8 @@ However, output leaves a lot to be desired:
|
|||
* each codebook's importance is effectively dependent on the speaker itself, so even having priority be a "learned" parameter is tough
|
||||
* RVQ codec don't have this problem as each level will always have the same type of importance (so much so that `AudioEncoder.level_weights` can be ignored for RVQ-codec-based models)
|
||||
* this architecture does not remove the inherent problem DAC-based models have, where the higher codebooks contribute too much noise
|
||||
* despite most of the dataset being resampled to 44KHz before encoding, the output audio has some inherent quality benefits (over the prior implementation's EnCodec based model) when it doesn't exhibit the below problem
|
||||
* both the small and the large model seemed to have hit a "capacity" limit
|
||||
* the "confidence" problem of the prior implementation seems to have emerged even for typical speakers
|
||||
* some other quirks and emergent behaviors inherent to the model I'm not aware of / can't recall
|
||||
* some other quirks and emergent behaviors inherent to the model I'm not aware of / can't recall
|
||||
* such as the demasking sampler loop being quite particular
|
|
@ -1384,7 +1384,6 @@ class Base_V2(nn.Module):
|
|||
logits = [ logit[..., -l:, :] for l, logit in zip(seq_lens, logits) ]
|
||||
|
||||
# perform min_p filtering of our logits
|
||||
"""
|
||||
if min_p > 0.0:
|
||||
logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
|
||||
|
||||
|
@ -1395,7 +1394,6 @@ class Base_V2(nn.Module):
|
|||
# do top-no logit processing
|
||||
if top_no > 0.0:
|
||||
logits = [ top_no_logits_processing(logit) for logit in logits ]
|
||||
"""
|
||||
|
||||
probabilities = [ F.softmax(logit, dim=-1) for logit in logits ]
|
||||
scores = [ torch.max(prob, -1)[0] for prob in probabilities ]
|
||||
|
|
Loading…
Reference in New Issue
Block a user