fixed inferencing since I did delete the len_emb, some more notes on the model since it seems I just had bad experimental settings
This commit is contained in: parent 61de653ad9, commit 9a7458cf17
@ -21,7 +21,7 @@ In theory, RVQ codecs should work better, as "importance" is consolidated in lev
* The glamor of `nvidia/audio-codec-44khz` might not be so glamorous, as its codebooks might be too dense for a model to operate on efficiently, and the codec's encoder/decoder is ***slow*** on ROCm.
* in other words, DAC might be preferable as a 44KHz medium.
* this might simply be a problem that can be "worked out" with more training time, hopefully, just as the "low confidence at higher codebook levels" problem eventually works itself out.
* this might also simply just be tied to the model's ability to follow closely to the prompt, as it seems more training does somewhat help
* this might also simply be tied to how closely the model can follow the prompt, as more training does seem to help, and there doesn't seem to be a specific codebook level that has confidence issues on severely unseen speakers.

## `AudioEncoder` / `AudioDecoder`
@ -29,11 +29,12 @@ Because this model operates on the full audio sequence at once, extra care is re

The `AudioEncoder` embeds each codebook level (injecting level-position embedding information), stacks them, passes the result through an MLP ( / residual feed-forward network ), then weighs each level with learned weights before summing it down to one sequence (a rough sketch follows the notes below).
* I feel most of this is kind of overkill, since I believe layer 0 of the underlying model could do this better, but an explicit "encoder" might also allow better tuning than an inherent one.
* Attention could also be used in place of the learned weights, as some speakers *could* prioritize different codebook levels for FSQ sequences.
* Attention could also be used in place of the learned weights, as different speakers *will* have different priorities in the audio spectrum, but I imagine this might end up as a learned feature that emerges within the attention heads of the underlying model itself.
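
For concreteness, a minimal sketch of that flow is below. The class and argument names, the shapes, and the softmax over the level weights are illustrative assumptions, not the repo's actual `AudioEncoder`.

```python
import torch
import torch.nn as nn

class AudioEncoderSketch(nn.Module):
    """Embed each codebook level, inject a level-position embedding, run an MLP, then do a learned weighted sum."""
    def __init__(self, n_levels=8, n_tokens=1024, d_model=1024, ffn_mult=4):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])
        self.level_pos = nn.Parameter(torch.zeros(n_levels, d_model))   # level-position embedding
        self.ffn = nn.Sequential(                                       # residual feed-forward network
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model),
        )
        self.level_weights = nn.Parameter(torch.ones(n_levels))         # learned per-level weights

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: [batch, n_levels, seq_len] integer codebook indices
        x = torch.stack([emb(codes[:, l]) for l, emb in enumerate(self.embs)], dim=1)  # [B, L, T, D]
        x = x + self.level_pos[None, :, None, :]        # inject level-position information
        x = x + self.ffn(x)                             # residual MLP
        w = self.level_weights.softmax(dim=0)           # normalizing the weights is an assumption
        return (x * w[None, :, None, None]).sum(dim=1)  # weighted sum down to one sequence: [B, T, D]
```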

The `AudioDecoder` projects the last hidden state through another feed-forward network (non-residual, with its own pre-layer norm). The decoder can be configured to either share one head across all levels, or dedicate a head to each level (a sketch follows the notes below).
* I feel non-shared heads might also be overkill, but they allow the decoder to better extract each dedicated codebook level from the last hidden state.
* It might not even be necessary to use an MLP, as the model was quick to fix itself after deleting, then shrinking, the feed-forward expansion factor to try and squeeze out throughput.
* because of this ablation, it's *probably* okay to just do per-codebook layer norm + an output head, but that experimentation is for another day.
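
A hedged sketch of the decoder side follows; reading "shared head" as one projection to every level at once is an assumption, as are the names and shapes.

```python
import torch
import torch.nn as nn

class AudioDecoderSketch(nn.Module):
    """Pre-layer-norm + non-residual FFN, then either one shared head or a head per codebook level."""
    def __init__(self, n_levels=8, n_tokens=1024, d_model=1024, ffn_mult=4, shared_head=False):
        super().__init__()
        self.n_levels, self.n_tokens = n_levels, n_tokens
        self.norm = nn.LayerNorm(d_model)               # its own pre-layer norm
        self.ffn = nn.Sequential(                       # non-residual feed-forward network
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model),
        )
        if shared_head:
            # "shared" read here as one projection to every level at once (an assumption)
            self.heads = nn.Linear(d_model, n_levels * n_tokens)
        else:
            # a dedicated head per codebook level
            self.heads = nn.ModuleList([nn.Linear(d_model, n_tokens) for _ in range(n_levels)])

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden: [batch, seq_len, d_model], the underlying model's last hidden state
        x = self.ffn(self.norm(hidden))
        if isinstance(self.heads, nn.Linear):
            b, t, _ = x.shape
            return self.heads(x).view(b, t, self.n_levels, self.n_tokens).permute(0, 2, 1, 3)
        return torch.stack([head(x) for head in self.heads], dim=1)     # [B, n_levels, T, vocab]
```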

### `ResidualAudioEncoder/Decoder`

@ -41,13 +42,27 @@ The implementation also includes an encoder/decoder targeted for residual codecs

This might simply be from it relying on cross-attention to deduce codebook level importance, rather than using a bog-standard feed-forward network with learned weighting of the codebooks (since the codebooks should always have a fixed relationship).
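
As a rough illustration of the cross-attention weighting mentioned above (this is a guess, not the repo's `ResidualAudioEncoder`): a learned query cross-attends over the per-level embeddings of each frame to decide how to mix the levels.

```python
import torch
import torch.nn as nn

class LevelCrossAttentionSketch(nn.Module):
    """A learned query cross-attends over the per-level embeddings of each frame to mix them."""
    def __init__(self, d_model=1024, n_heads=8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, d_model))                     # one learned query
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, level_embs: torch.Tensor) -> torch.Tensor:
        # level_embs: [frames, n_levels, d_model], the stacked per-level embeddings of each frame
        q = self.query.expand(level_embs.shape[0], 1, -1)                      # [frames, 1, d_model]
        mixed, _ = self.attn(q, level_embs, level_embs)                        # attend over the levels
        return mixed.squeeze(1)                                                # one vector per frame
```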

## Pure NAR
## Modalities

The same core modalities are supported when inferencing all codebooks in parallel.

While the model *can* still be trained as a hybrid AR/NAR, the reference `nemo-*-llama-8` family of models is trained purely as a masked NAR transformer, due to:
* the previous implementation nicely handled storing embeddings of different modalities, while this implementation does not have an inherent mechanism to do so without relying on some form of additional weights somewhere
* this is *probably* because the embeddings signal to the model whether to predict tokens in place or predict the next token in the sequence
* *technically* the NAR sequences can be trained to predict the next token instead, but that's completely untested and may cause problems
* even training with a causal (triangle) attention mask lobotomizes the model severely
* I don't think it's worth the compute at the moment to brute force through it
* the original implementation of NAR-demasking showed the model was too flawed to be used when naively using a causal (triangle) mask, so I would not want to tempt fate and ruin the model
* an autoregressive decoder *could* be emulated with decoding in chunks (see the sketch after this list)
* in addition to some extra lines of code (which would probably just re-use the "rolling context" feature), I imagine an attention mask similar to sliding-window attention is required
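
A sketch of that chunked emulation, with `demask_chunk` standing in for the NAR demasking routine and a hypothetical rolling-context window; this is speculative, not implemented code.

```python
import torch

def chunked_decode(demask_chunk, total_frames, n_levels, chunk=50, context=200):
    """Decode left-to-right in fixed-size chunks, each conditioned on a rolling window of prior audio."""
    decoded = torch.zeros(n_levels, 0, dtype=torch.long)
    while decoded.shape[1] < total_frames:
        prefix = decoded[:, -context:]                      # sliding window, akin to sliding attention
        size = min(chunk, total_frames - decoded.shape[1])
        frames = demask_chunk(prefix=prefix, frames=size)   # hypothetical call -> [n_levels, size]
        decoded = torch.cat([decoded, frames], dim=1)
    return decoded
```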

### Pure NAR

Like the previous implementation, this model can operate entirely non-autoregressively (and with non-causal attention) as a masked transformer. The demasking inference loop is the same as in the previous implementation, where each demasking step can mask off an entire timestep based on the sum of the logit scores, or independently (where each level has its own mask).
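
For reference, a stripped-down version of such a demasking loop might look like the following. The `model` callable (returning per-level logits for the whole sequence), the cosine schedule, and re-scoring every position each step are assumptions; the real sampler has more knobs (CFG, temperature, etc.).

```python
import math
import torch

def demask_sample(model, seq_len, n_levels, mask_token, steps=25, joint_levels=True):
    """Iteratively demask a fully-masked sequence, keeping the most confident predictions."""
    codes = torch.full((n_levels, seq_len), mask_token)        # start fully masked
    masked = torch.ones(n_levels, seq_len, dtype=torch.bool)
    for step in range(steps):
        logits = model(codes)                                  # assumed shape: [n_levels, seq_len, vocab]
        conf, pred = logits.softmax(dim=-1).max(dim=-1)        # per-position confidence + argmax
        codes = torch.where(masked, pred, codes)               # fill only the still-masked slots
        # cosine schedule: how many positions remain masked after this step
        remaining = int(seq_len * math.cos((step + 1) / steps * math.pi / 2))
        masked[:] = True
        if joint_levels:
            # score whole timesteps on the summed confidence across levels
            keep = conf.sum(dim=0).topk(seq_len - remaining).indices
            masked[:, keep] = False
        else:
            # each level keeps its own mask, scored independently
            keep = conf.topk(seq_len - remaining, dim=-1).indices
            masked.scatter_(1, keep, False)
        codes[masked] = mask_token                             # re-mask low-confidence positions
    return codes
```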

Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11`, and decoded autoregressively.
Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11` and decoding it autoregressively, while allowing some wiggle room in the duration, as it's no longer sampled from logits.
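
A minimal sketch of that readout, assuming a simple regression-style head at the separator position; the actual head and loss in the repo may differ.

```python
import torch
import torch.nn as nn

class DurationHeadSketch(nn.Module):
    """Regress a duration from the hidden state at the separator after the input prompt."""
    def __init__(self, d_model=1024):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)   # one continuous output instead of digit logits

    def forward(self, hidden: torch.Tensor, sep_index: torch.Tensor) -> torch.Tensor:
        # hidden: [batch, seq_len, d_model]; sep_index: [batch] position of the separator token
        feat = hidden[torch.arange(hidden.size(0)), sep_index]   # [batch, d_model]
        return self.proj(feat).squeeze(-1)                       # predicted duration in frames

# trained alongside the usual token loss, e.g.
#   loss = token_loss + F.l1_loss(duration_head(hidden, sep_index), target_frames.float())
```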

## Pure AR
### Pure AR

Unlike the previous implementation, this model can also operate entirely autoregressively as a causal transformer, where each step samples *all* codebooks at one code-frame.
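
A rough sketch of that loop, assuming a `model` callable that returns per-level logits over the sequence and a hypothetical `stop_token`; it is illustrative, not the repo's sampler.

```python
import torch

@torch.no_grad()
def ar_generate(model, prompt_codes, max_frames=1500, stop_token=1024, temperature=1.0):
    """Causal loop: every step emits one code-frame containing *all* codebook levels."""
    codes = prompt_codes.clone()                         # [n_levels, prompt_len]
    for _ in range(max_frames):
        logits = model(codes)[:, -1]                     # assumed: per-level logits for the next frame
        probs = (logits / temperature).softmax(dim=-1)   # [n_levels, vocab]
        frame = torch.multinomial(probs, num_samples=1)  # sample every level at once -> [n_levels, 1]
        if (frame == stop_token).any():                  # hypothetical stop condition
            break
        codes = torch.cat([codes, frame], dim=1)         # append the new code-frame
    return codes[:, prompt_codes.shape[1]:]              # only the generated frames
```
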
@ -70,6 +85,39 @@ Both flavors were trained on the previously used dataset, but English only (as I

Additional tasks beyond text-to-speech (`tts`) were not trained for either model, as they're very low priority, and the implementation might have had the logic to train for them gutted.

### Experimental Settings

Currently, both models are trained using these experimental flags:
```
unified_position_ids: False # per-segment position IDs

rvq_levels_p: "equal" # distribution of codebook levels to target during training
audio_level_loss_factors: "normal" # distribution of loss weights per codebook (should be "equal" when speech is confident enough)

masking_train_p: 1.0 # pure NAR (always train with masking/demasking)
masking_ratio: 0.8 # a fixed masking ratio proves to be better
ignore_inputs_for_loss: True # False is not implemented
use_segmented_attention_mask: True # restricts each segment to itself + prior segments (in other words, does not let the text/prom see further into the future outside of its segment) (sketched below)
use_streamlined_calc_loss: True # False has no effect now

noncausal_masks: True # creates non-causal masks
resp_parallel_training: True # trains all codebook levels in parallel
len_parallel_training: True # trains length duration alongside normal training

cfg_cond_dropout_p: 0.02 # was originally 0.3, but I think it's too much after a while
cfg_prom_dropout_p: 0.01 # was originally 0.2

use_raw_text_p: 0.1 # I don't know what's a good value, and I haven't tried inferencing with raw text yet
```
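
The segmented attention restriction behind `use_segmented_attention_mask` can be pictured with a small sketch (illustrative, not the repo's implementation): each position may attend to its own segment and any earlier segment, but never a later one.

```python
import torch

def segmented_attention_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """True where a query position may attend to a key position: own segment + prior segments."""
    q = segment_ids[:, None]   # query segments, [seq_len, 1]
    k = segment_ids[None, :]   # key segments,   [1, seq_len]
    return k <= q              # [seq_len, seq_len] boolean mask

# example: 3 text tokens (segment 0), 2 prom tokens (1), 4 resp tokens (2)
mask = segmented_attention_mask(torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2]))
```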

These settings should be avoided:
* `predict_causally`: forces the model to always predict the next token instead of the token in place, but untested for an actual model
* the original NAR-demasking experiment suggests this is probably fine, but I don't want to take any risks
* `logit_normalization`: *should* provide some regularization for the logits, but in reality it lobotomizes inferencing output.
* `parallel_attention_mask_dropout`: governs the rate of flipping to a causal (triangle) mask during training
* there's *some* reason to do this ablation, but it ruins the model (although the model can easily recover if erroneously trained with this)
* the model might eventually train itself to work around this, or it might need to be aware of it from the beginning, but it's not something to toy with.

## Benefits and Caveats

To be evaluated, as additional training time is required, despite progression seemingly plateauing.

@ -80,6 +128,8 @@ Additionally, this implementation paves the way a ton of neat features, such as:
* live playback through autoregressive inferencing, as all codebooks are predicted for each step
* could also be "mocked" by doing NAR-len demasking in chunks
* inherent audio upscaling, as the model is trained on a 44KHz codec
* some other features I can't recall

However, I'm not sure if the additional complexity justifies it.
* the current hurdle is that speaker similarity is ***dismal***
* parallel inferencing on all codebooks might have enough of a performance hit that sequentially inferencing the codebooks might be preferable

@ -617,10 +617,6 @@ class Base_V2(nn.Module):
	# Audio length prediction task
	# Sequence: <phn><sep><rvq lvl><prom><sep><len>
	elif task_type == "len":
		# throw an error so we don't silently train without this
		if self.len_emb is None:
			raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")

		# insert the phn prompt
		if phns_list is not None and phns_list[i] is not None:
			inputs[i].append( ( "phn", phns_list[i] ) )