added mixed modality AR+NAR-len to generate a short prefix through the AR, then inference with said prefix through the NAR-len (need to experiment with it more to ensure that the masked off tokens are the only tokens getting updated)
parent db64e6cb59
commit 67f7bad168
@ -9,7 +9,7 @@ At the time, state-of-the-art neural-based TTS solutions were sparing. TorToiSe
# Why this VALL-E?

Unlike the paper, this VALL-E aims to:
* be lightweight as possible, only requiring one model to load and use (and EnCodec/Vocos).
* be lightweight as possible, only requiring one model to load and use (and EnCodec/Vocos as an audio encoder/decoder).
  + Even the original VALL-E requires a separate AR and a NAR.
* keep training and finetuning (be it the base model or through LoRAs) accessible to anyone.
  + Bark was needlessly complex in providing even additional voices to use.
@ -17,19 +17,7 @@ Unlike the paper, this VALL-E aims to:
* provide decent zero-shot text-to-speech synthesis, both without requiring sampling adjustments and providing thorough sampler settings.
* provide additional, easy-to-use functionality that other solutions don't offer.

## Caveats

Despite how lightweight it is in comparison to other TTS's I've meddled with, there are still some caveats, be it with the implementation or model weights:
* the audio embeddings have some quirks from having the AR's RVQ level 0 separate from the NAR's RVQ level 0 (sharing them caused some problems in testing)
* the trainer / dataloader assumes there is zero variation between a speaker's utterances, and thus it can only extract the basics of a speaker's features rather than deeper features (like prosody, tone, etc.) when performing inferences.
  + ~~however, trying to work around this would require training under `tts-c` (VALL-E continuous) mode or modifying an input prompt enough to where its quantized representation differs enough from the output response the prompt derives from.~~
  + to remedy this, training benefits from calculating the most similar utterances for each utterance, and using that as the input prompt for training.
* the trainer's default RVQ level distribution prioritizes lower RVQ levels over higher RVQ levels, as the lower levels contribute to the final waveform more; however, this leaves some minor artifacting that arises in the higher RVQ levels due to inaccuracy issues.
  + summing the audio embeddings for later RVQ levels seems to help?
  + `model.experimental.p_rvq_levels: [0,0,0,0,0,0,0,1,2,3,4,5,6,7]` seems to help? (see the sketch after this list)
* speakers that aren't similar to an audiobook narrator voice have similarity issues, as the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling.
  + although LoRAs help a ton for fixing results for a single voice.
  + a diverse dataset in prosody and speaker (such as a corpus sourced from dramatic media like video games) helps a ton, but still has issues for speakers not similar to any seen speakers.
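
To illustrate how that `p_rvq_levels` list is meant to bias training, it can be read as a pool that a level is drawn from per sample, so repeating `0` seven times out of fourteen entries weights roughly half of all steps toward RVQ level 0. This is a hedged sketch of that reading, not the trainer's actual code:

```python
import random
from collections import Counter

# illustrative pool: level 0 appears 7 out of 14 times, so it is drawn ~50% of the time,
# while each of the higher levels 1..7 is drawn ~7% of the time
p_rvq_levels = [0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]

def pick_rvq_level() -> int:
    # one level per training sample, biased toward the levels that contribute most to the waveform
    return random.choice(p_rvq_levels)

print(Counter(pick_rvq_level() for _ in range(10_000)))  # roughly 5000 zeros, ~700 of each other level
```
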
However, at this point in time, the implementation is rather divorced from VALL-E and its derivative papers, but the core principle is still followed.

## To-Do

@ -57,8 +45,7 @@ Despite how lightweight it is in comparison to other TTS's I've meddled with, th
* [ ] speed up inferencing
  - KV caching both yields broken output and quadratically slow output, unless I'm doing something grossly wrong.
  - A pure HF model is the only way to fix this, but converting the model to one is a bit of a chore.
  - Speculative sampling seems overkill for small models (and in reality seems like it's better to just train a larger model).
  - Self-speculation through layer-skipping doesn't offer any tangible speedups, sadly.
* [x] provide a pure NAR model that foregoes most of the inferencing slowdowns a regular AR+NAR model will provide.
* [ ] replace the phonemizer with something that doesn't depend on espeak
* [ ] train the model to handle text => phoneme (without a hit to the rest of the model)
* [ ] ...and phonemes => text
@ -15,6 +15,7 @@ While the original paper called for a separate AR model and a NAR model, and by
## The AR (Autoregressive) Model

The AR is responsible for generating the first RVQ level of the audio codes for a given output. References to "outputs from the AR" refer to this level, as it contributes to the final waveform the most.
* Some models may refer to this level as the "coarse" level.
* The benefit of autoregressively decoding for this code is that it offers better output while also "encoding" the duration within the sequence itself, as the stop token will depend on the length of the sequence (see the sketch below).
* The downside is that it does take most of the compute time to iterate through the sequence one step at a time.
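
As a rough illustration of how the stop token doubles as the duration signal, a hypothetical greedy decode loop could look like the following; the `model` call signature and argument names are placeholders, not the repo's actual API:

```python
import torch

def ar_decode_rvq0(model, phns, prom, stop_token: int, max_steps: int = 1500) -> torch.Tensor:
    """Illustrative greedy AR decode of RVQ level 0: one token per step until the stop token,
    so the length of the returned sequence *is* the output duration (in codec frames)."""
    resp: list[int] = []
    for _ in range(max_steps):
        logits = model(text=phns, prom=prom, resp=torch.tensor(resp, dtype=torch.long))
        token = int(logits[-1].argmax())  # greedy pick for the next timestep
        if token == stop_token:
            break
        resp.append(token)
    return torch.tensor(resp, dtype=torch.long)
```
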
@ -22,15 +23,22 @@ Autoregressive training is performed by having each token predict the next token
One way to work around the time cost is to instead decode more than one token at a time.
* In theory, for a unified AR+NAR model, this *should* be an easy task, as the model can already decode tokens in parallel.
* In reality, this isn't the case. Specifying a `cfg.model.experimental.causal_size > 1` will have the output sound *fine* every Nth timestep, as the following tokens aren't predictable enough.
* In reality, this isn't the case. Specifying a `cfg.model.experimental.causal_size > 1` with adequate training will have the output sound *fine* every Nth timestep and every other timestep not so fine, as the following tokens aren't predictable enough.
  + *However*, this may simply be a sampling problem, as this experiment was done with outdated ideas on how to sample the AR, and should be worth revisiting.
* VALL-E 2's paper proposes merging code sequences together into one embedded token for a speedup, but their solution seems too complex to warrant a fundamental retrain.

I personally feel that autoregressive encoding offers a specific-yet-hard-to-quantify expressive quality that the NAR (and pure NAR solutions) does not offer, but further testing is required to substantiate the claim.
Sampling the AR does not necessarily require a specific sampling temperature, as:
* lower temperatures follow the prompt better, at the cost of variety in the outputs, and the need to either use classifier-free guidance or repetition penalty to wrangle the output.
* higher temperatures are possible, but are more prone to not adhering to the prompt.

Traditional samplers for text-gen models can apply to the AR (especially rep/len pen), but more exotic samplers (mirostat, DRY, etc.) don't seem to offer much besides serving as bandaid solutions for a lacking AR.

Compared to non-autoregressive decoding, I personally feel that autoregressive encoding offers a specific-yet-hard-to-quantify expressive quality that the NAR (and pure NAR solutions) does not offer.

## The NAR (Non-autoregressive) Model

The NAR is responsible for generating the remaining RVQ levels of the audio codes for a given output. References to the "outputs from the NAR" refer to the underlying "levels" for a given waveform, as each further level contributes to the final waveform less significantly than the previous.
* Some models may refer to this level as the "fine" level.

As decoding is done non-autoregressively, the model can process tokens "in place" and have them attended to one another in the past and future, thus speeding up output and allowing for "more accurate" outputs.
@ -44,9 +52,13 @@ One problem exhibited from a NAR is producing arfifacts ("crust") in the final w
* `token_dropout_error`: This will randomly nudge a small percentage of tokens from the prior RVQ level to simulate wrong tokens being predicted.
* `token_dropout_rate`: This will randomly mask off tokens from the prior RVQ level with a mask token, to try and have the model not-strongly-rely on the given input.
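
A minimal sketch of what these two training-time corruptions could do to the prior level's codes; the helper below, the ±1 nudge, and the tensor shapes are illustrative assumptions rather than the repo's actual implementation:

```python
import torch

def corrupt_prior_level(codes: torch.Tensor, mask_token: int, n_tokens: int = 1024,
                        token_dropout_error: float = 0.05, token_dropout_rate: float = 0.05) -> torch.Tensor:
    """Corrupt the prior RVQ level's codes before they are embedded for the next level."""
    codes = codes.clone()
    # token_dropout_error: nudge a small percentage of tokens (here by +/-1) to simulate mispredictions
    error = torch.rand(codes.shape, device=codes.device) < token_dropout_error
    nudge = torch.randint(0, 2, codes.shape, device=codes.device) * 2 - 1
    codes[error] = (codes[error] + nudge[error]).clamp(0, n_tokens - 1)
    # token_dropout_rate: replace a small percentage of tokens with the mask token outright
    dropout = torch.rand(codes.shape, device=codes.device) < token_dropout_rate
    codes[dropout] = mask_token
    return codes
```
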
Sampling from the NAR absolutely necessitates a low temperature or to be greedily sampled, as higher temperatures lead to the aforementioned artifacts in the final waveform.

Traditional samplers do not seem to offer much effect in the output, as it seems the outputs from the NAR are decent enough.

### Pure NAR

The pure NAR (`NAR-len`) model is a model-type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
The pure NAR (`NAR-len`) model is a modality that inferences audio tokens purely non-autoregressively.

However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
* The former problem is easily "solved" by training a `len` classification task.
@ -59,14 +71,16 @@ The NAR-len model keeps things simple by:
* training with a fixed masking ratio (80% of the tokens are masked and trained to predict the remaining tokens)
  * [this paper](https://arxiv.org/abs/2406.05478v1) mentions a fixed ratio during training yields better results than randomly picking a masking ratio.
  * randomly picking a duration ~~is actually very ungood and harms the model during training~~ actually doesn't matter much.
    * theoretically, it should help later stages in demasking to better rely on the non-masked tokens, but who knows.
* not including any specific timestep embedding information
  * some solutions add in the (sinusoidal position'd) timestep embedding, either on top of the input embeddings, or as some normalization weight around the attention head (before and after).
  * it does not seem to be necessary whatsoever to require this, especially when training under a fixed masking ratio.
    * in reality, the model shouldn't really need to reference this anyways, as training NAR RVQ level 0 is simply to refine a noised/masked off sequence of tokens.
    * in reality, the model shouldn't really need to reference this anyways.
* predicting the "duration" (the output audio token window) is kept within the model itself, by autoregressively inferencing the duration for a given input prompt (text + audio).
  * the model can already "know" the duration for a given prompt from an AR RVQ level 0, by predicting when to output the stop token, so it makes sense to re-use the model for this.
  * the output length is a simple tokenized sequence where each token is a base-10 digit.
    * it could be in any base, but it's simple to just treat each token ID as a digit, then cast the string to an int.
  * some checkpoints of the model seem to adhere well to outputting silence at the end if the requested duration exceeds the actual duration.
* inferencing is a simple loop that keeps the best k masked-off tokens per step, and remasks the remaining (see the sketch below).
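
Putting the duration prediction and the demasking loop together, a stripped-down sketch could look like the following. It assumes a `model` call that scores every position in parallel; the cosine schedule mirrors the `noise_p` used in the code further below, while the function signatures and greedy confidence scoring are simplifications, not the exact implementation:

```python
import math
import torch

def decode_len(digit_tokens: list[int]) -> int:
    # each `len` token is a base-10 digit, so the duration is just the digits joined and cast to an int
    return int("".join(str(d) for d in digit_tokens))

def nar_len_demask(model, phns, prom, duration: int, mask_token: int, max_steps: int = 25) -> torch.Tensor:
    """Start fully masked, then per step keep the most confident tokens and remask the rest."""
    resps = torch.full((duration,), mask_token, dtype=torch.long)
    scores = torch.ones(duration)  # 1.0 = low confidence, i.e. likely to be remasked
    for timestep in torch.linspace(0.0, 1.0, max_steps):
        # cosine schedule: how many positions stay masked at this step
        noise_p = math.cos(timestep.item() * math.pi * 0.5)
        masked = scores.topk(max(int(noise_p * duration), 1)).indices
        resps[masked] = mask_token
        # predict all positions in parallel, conditioned on the tokens left unmasked
        logits = model(text=phns, prom=prom, resps=resps)  # (duration, n_audio_tokens)
        confidence, ids = logits.softmax(dim=-1).max(dim=-1)
        resps[masked] = ids[masked]  # only the masked-off positions get updated
        scores = 1.0 - confidence    # least confident tokens get remasked next step
    return resps
```
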
Because the model already leverages the magic of attention to derive phoneme-alignment, such annotations are still not required (but they probably help with a naive sampler).
@ -83,7 +97,8 @@ It is ***crucial*** to:
* use unfiltered/unprocessed logit scores:
  * not that crucial, but helps stability.

It is not required to train a model from scratch to use this modality, as using existing weights works just as well, if not better (as it can piggyback off the original model).
It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model).
* additional training is still required to help with confidence issues and to condition the model to not fall apart for longer durations.

## Embeddings (and Classifiers)
@ -97,7 +112,7 @@ With attention-based transformers, most embeddings can serve as a token itself a
Classifiers are the final output head / projection layer that processes the last hidden states of a model into a probability distribution for each token.

Out of paranoia, each head is split for each macro-task (RVQ level) and an auxiliary head for tasks `stt` and `len`, even though the core half of the model's training was with a single output head.
Out of paranoia, each head is split for each macro-task (RVQ level, `stt`, and `len`), even though the core half of the model's training was with a single output head.
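
A rough sketch of what per-macro-task output heads could look like; the module layout, vocabulary sizes, and task names below are illustrative assumptions rather than the repo's actual classes:

```python
import torch
from torch import nn

class SplitClassifiers(nn.Module):
    """Illustrative per-task output heads: one projection per RVQ level, plus `stt` and `len`."""
    def __init__(self, d_model: int, n_audio_tokens: int = 1024, n_text_tokens: int = 256,
                 n_len_tokens: int = 11, n_rvq_levels: int = 8):
        super().__init__()
        heads = {f"rvq_{level}": nn.Linear(d_model, n_audio_tokens + 1) for level in range(n_rvq_levels)}
        heads["stt"] = nn.Linear(d_model, n_text_tokens)  # speech-to-text head
        heads["len"] = nn.Linear(d_model, n_len_tokens)   # base-10 digits (+ stop), per the `len` task
        self.heads = nn.ModuleDict(heads)

    def forward(self, hidden: torch.Tensor, task: str) -> torch.Tensor:
        # route the last hidden states through the head for the requested macro-task
        return self.heads[task](hidden)
```
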
### Text Embeddings
@ -236,6 +251,7 @@ The model can be prompted in creative ways to yield some interesting behaviors:
* classifier-free-guidance-aware training does fix this, but this property emerges without it.
* prompting with an input text prompt being the transcription of the input audio prompt will have the response follow very closely to the input prompt (despite not doing input=output training).
  * this should allow for easy transcription editing without much fuss.
  * the `NAR-len` greatly exhibits this property, although it sometimes does keep any noise in the background.

# `models/*`
@ -196,6 +196,7 @@ class TTS():
input_prompt_length = 0,
load_from_artifact = False,
nar_len_prefix_length = 0,

seed = None,
out_path=None,
@ -271,7 +272,7 @@ class TTS():
# to-do: add in case for experimental.hf model
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
if model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # don't need more than that
kwargs = {}
# nasty hardcode to load a reference file and have that as the input target
if load_from_artifact and load_from_artifact.exists():
@ -283,6 +284,15 @@ class TTS():
len_list = [ resp.shape[0] ]

kwargs["resps_list"] = [ resp[:, :1] ]
# kludge experiment
elif nar_len_prefix_length > 0:
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**(sampling_kwargs | {"max_duration": nar_len_prefix_length}),
)
kwargs["resps_list"] = [ resp if resp.dim() == 1 else resp[:, 0] for resp in resps_list ]

resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
disable_tqdm=not tqdm,
@ -264,23 +264,42 @@ class AR_NAR(Base):
max_steps = math.floor(max_steps * (end_noise - start_noise))

len_list = [ clamp(l, min_length, max_length) for l in len_list ]

# if we're denoising from an existing sequence
if start_noise > 0.0 and resps_list is not None:
# flatten if needed
resps_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in resps_list ]
# gen masking ratio
noise_p = math.cos( start_noise * math.pi * 0.5 )
mask = [ torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device ) for seq_len in len_list ]
resps_list = [ torch.where( is_masked, self.stop_token, resps if resps.dim() == 1 else resps[:, 0] ) for is_masked, seq_len, resps in zip( mask, len_list, resps_list ) ]
# generate scoring mask (because the above mask will get masked off per the scores, so we do not need to mask beforehand)
scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
# deduce that this is a prefix
elif resps_list is not None:
# number of remaining tokens
tokens_to_mask = [ l - resps.shape[0] for resps, l in zip( resps_list, len_list ) ]
# pad with masked tokens
resps_list = [ torch.concat([ resps if resps.dim() == 1 else resps[:, 0], torch.tensor( [ self.stop_token ] * l, dtype=resps.dtype, device=resps.device ) ]) for resps, l in zip( resps_list, tokens_to_mask ) ]
# update scores to ignore the prefix
scores = [ torch.concat( [ torch.zeros((resps.shape[0],), dtype=torch.int16, device=device), torch.ones((l), dtype=torch.int16, device=device) ] ) for resps, l in zip( resps_list, tokens_to_mask ) ]
# set start noise
# only the first because we do not have variable noising at the moment
# *technically* the prefix can be a fixed portion for all inputs in a batch, rather than a fixed length
# this will set the starting noise_p with the right ratio
start_noise = 2 / math.pi * math.acos(resps_list[0].shape[0] / len_list[0])
else:
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]

scores = [ torch.zeros((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
# fill scores
scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]

quant_levels = [ level for _ in range(batch_size) ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
prev_list = resps_list

for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
# update previous list of tokens
prev_list = resps_list
# ramp down over time
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
@ -352,8 +371,6 @@ class AR_NAR(Base):
temperature=0.0,
**sampling_kwargs,
)
# update previous list of tokens
prev_list = resps_list
# get sampled tokens
sampled_ids = filtered_sampled.ids
# keep unmasked tokens