tweaks and changes
This commit is contained in:
parent
2fbeacfe92
commit
23fdba0c98
|
@ -9,7 +9,7 @@ The beauty of a transformer, I feel, is that you can easily define any task at i
|
|||
|
||||
The inputs are sequenced in a way that the given task requires automatically, and the outputs are handled as per the class that extends the base model.
|
||||
|
||||
While the original paper called for a separate AR model and a NAR model, and by treating the AR and the NAR as unique tasks, you can actually train a unified model for effectively free, as the internal states of the two should overlap quite a lot.
|
||||
While the original paper called for a separate AR model and a NAR model, and by treating the AR and the NAR as unique tasks, you can actually train a unified model (`AR+NAR`) for effectively free, as the internal states of the two should overlap quite a lot.
|
||||
|
||||
## The AR (Autoregressive) Model
|
||||
|
||||
|
@ -45,7 +45,7 @@ One problem exhibited from a NAR is producing arfifacts ("crust") in the final w
|
|||
|
||||
### Pure NAR
|
||||
|
||||
The pure NAR (`nar-len`) model is a model-type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
|
||||
The pure NAR (`NAR-len`) model is a model-type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
|
||||
|
||||
However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
|
||||
* The former problem is easily "solved" by training a `len` classification task.
|
||||
|
@ -54,14 +54,22 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
|
|||
* masking to emulate diffusion noising is best working solution, but has a lot of training challenges.
|
||||
* existing solutions like Muse (text to image) and MaskGCT (text to speech) do this
|
||||
|
||||
To-do: fill out this more when it works. Getting this to work is a huge pain.
|
||||
* Some masked transformers do not "inject" any timestep information (Text-To-Image Muse as far as I can tell)
|
||||
* Others "expose" it by applying a timestep embedding after pre/post attention normalization
|
||||
* Except F5-TTS only does this pre for its DiTS, but not UnetT
|
||||
* MaskGCT does it both pre and post
|
||||
* the test trainier actually degrades the output immensely when doing this
|
||||
* I'm sure I've seen a masked transformer not have CFG, but most of them seem to do.
|
||||
* ***Extreme*** care is required.
|
||||
The NAR-len model keeps things simple by:
|
||||
* training with a fixed masking ratio (80% of the tokens are masked and trained to predict the remaining tokens)
|
||||
* [this paper](https://arxiv.org/abs/2406.05478v1) mentions a fixed ratio during training yields better results than randomly picking a masking ratio.
|
||||
* not including any specific timestep embedding information
|
||||
* some solutions add in the (sinusoidal position'd) timestep embedding, either on top of the input embeddings, or as some normalization weight around the attention head (before and after).
|
||||
* it does not seem to be necessary what-so-ever to require this, especially training under a fixed masking ratio.
|
||||
* in reality, the model shouldn't really need to reference this anyways, as training NAR RVQ level 0 is simply to refine a noised/masked off sequence of tokens.
|
||||
* predicting the "duration" (the output audio token window) is kept within the model itself, by autoregressievly inferencing the duration for a given input prompt (text + audio).
|
||||
* the model can already "know" the duration for a given prompt already from an AR RVQ level 0, by predicting when to output the stop token, so it makes sense to re-use the model for this.
|
||||
* the output length is a simple tokenized sequence where each token is a base-10 digit.
|
||||
* it could be in any base, but it's simple to just treat each token ID as a digit, then cast the string to an int.
|
||||
* inferencing is a simple loop that simply takes the best masked-off k tokens per step, and remasks the remaining.
|
||||
|
||||
In theory, demasking for the NAR's RVQ level 0 can also be applied to the remaining RVQ levels to further improve the output from the remaining levels.
|
||||
* this isn't necessary as the model already has a strong enough relationship between the prompt, the prior levels, and the targeted level.
|
||||
* this is technically already offered with `cfg.model.experimental.token_dropout_rate` which mirrors masking, but experimentation has not been done to a large degree.
|
||||
|
||||
## Embeddings (and Classifiers)
|
||||
|
||||
|
|
|
@ -3,45 +3,43 @@
|
|||
Training is very dependent on:
|
||||
* the quality of your dataset.
|
||||
* clean utterances and accurate transcriptions go a long way.
|
||||
* a diverse dataset in prosidy and speakers help a ton.
|
||||
* a diverse dataset in prosody and speakers help a ton.
|
||||
* how much data you have.
|
||||
* training from scratch requires upwards of 15K hours.
|
||||
* training from scratch requires upwards of 15K hours at minimum.
|
||||
* training new languages from the base model simply requires maybe ~2K hours each.
|
||||
* the bandwidth you quantized your audio to, as this affects the how many tokens are processed per step.
|
||||
* the underlying model architecture used.
|
||||
* some models behave better than others for a unified approach, others do not.
|
||||
|
||||
For single GPUs, simply running `python3 -m vall_e.train --yaml="./training/config.yaml`.
|
||||
|
||||
For multiple GPUs, or exotic distributed training:
|
||||
* with `deepspeed` backends, simply running `deepspeed --module vall_e.train --yaml="./training/config.yaml"` should handle the gory details.
|
||||
* with `local` backends, simply run `torchrun --nnodes=1 --nproc-per-node={NUMOFGPUS} -m vall_e.train --yaml="./training/config.yaml"`
|
||||
|
||||
You can enter `save` to save the state at any time, or `quit` to save and quit training.
|
||||
|
||||
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
|
||||
|
||||
Some additional flags can be passed as well:
|
||||
* `--eval`: only run the evaluation / validation pass, then exit afterwards.
|
||||
* `--eval-random-text-prompts`: use random text prompts for the evaluation pass, rather than the provided text prompts in the dataset.
|
||||
|
||||
A training paradigm that works for me is:
|
||||
* setting the dataloader to sort by duration, then training one epoch, so the model starts with small utterances then trains to larger ones.
|
||||
* the daring can wait until coherent speech emerges, then move to the next step
|
||||
* some additional training using a shuffled dataloader, as the model will be fixated towards whatever duration range it was trained under.
|
||||
* setting the dataloader to sort by duration, then training until coherent speech emerges, so the model can start with the bulk of learning on small, insignificant utterances, then working its way up to larger ones.
|
||||
* ~80% of the epoch from duratio ranges 1.0seconds to 0.8seconds is good enough, as most of the training from this part is just to train the model to speak at all.
|
||||
* additional training using a shuffled dataloader, as the model will be fixated towards whatever duration range it was trained under.
|
||||
* the remaining bulk is to try and have the model better adhere to the prompt as well.
|
||||
* additional training for sampling per speaker, to better help diversify how well it can perform for a range of speakers, rather than just speaking itself
|
||||
* I don't think this is crucial, but speaker-based sampling seems to be a huge placebo if anything.
|
||||
|
||||
I don't remember the exact numbers off the top of my head, but a good loss/accuracy/gradient norm to look out for when coherent speech emergies are:
|
||||
* loss <3.0
|
||||
* acc >0.7
|
||||
* grad_norm <0.2
|
||||
* I don't think this is crucial, but speaker-based sampling seems to be a placebo if anything.
|
||||
|
||||
Training under `float16` should be fairly simple, but care is required to keep the loss scaling factor above 8K, and probably even 16K.
|
||||
* At the very least for pre-trained models, low enough loss scales will irreparably fry the model, and no amount of training afterwards seems to "fix" it.
|
||||
* The current DeepSpeed configuration should keep the loss scale capped to 32K, but this so far is only validated for pre-trained models.
|
||||
* The current DeepSpeed configuration should keep the loss scale capped to 32K; normal training does not seem to have the loss scale ever want to dip below this at least.
|
||||
* Training under `bfloat16` does not have to worry about this as there's no need for loss scaling, but I feel the model performs better when trained under `float16`+AMP rather than `bfloat16` (with or without AMP).
|
||||
|
||||
When training from scratch, maybe 30% of the time spent training is getting coherent speech, with a loose following of the prompt. The remaining bulk of the work is getting the model to closely-er resemble the input prompt.
|
||||
* an accuracy of at least 50% seems to be where coherent speech emerges.
|
||||
* an accuracy of at least 68% is about where it's a good enough model that adheres to the prompt, but requires quite a lot of work to get there.
|
||||
|
||||
As far as typical hyperparameters go:
|
||||
* as I'm using a batched dataloader, I don't have a uniform size amongst the batches, but I believe my average batch size is between 96 to 128 samples per batch (24-32 samples per GPU for 4 GPUs) per step.
|
||||
* the gradient accumulation factor gets adjusted where I feel is best, where I keep it to 1 (no gradient accumulation) for the first milestone of getting coherent speech, and then ramping it up to 2 then 4 as training further goes on, to try and smooth out the gradients.
|
||||
* more effective samples per update step is technically better, but getting coherent speech as fast as possible is preferable, so prioritizing many updates until then is the goal.
|
||||
* afterwards, reducing the gradient norm is the goal, increasing the amount of samples per update step.
|
||||
* as I primarily use prodigyopt, I don't need to worry about the learning rate. Sure, you can lower the `d_coef` (which the trainer will adjust in lieu of the learning rate itself), but I don't feel like it effects things moreso than just adjusting the gradient accumulation factor.
|
||||
|
||||
With the other "hyperparameters" such as ratios for RVQ levels, tasks, etc:
|
||||
* `rvq_levels_p` to `auto` is fine. The primary level is RVQ level 0, so having it majorly represented is fine.
|
||||
* it might be needed to later prefer a more balanced distribution (such as `[0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]`) to get rid of any confidence issues in RVQ levels 1+, but I felt naively doing this harms the RVQ 0.
|
||||
* `prompt_similar_p` can be pretty much whatever > `0.5`. I've stuck with either `0.75` or `0.825` to prioritize adhering closely-er to the prompt, but still have random prompts used to help the model interanlly "model" what a speaker should sound like. In theory.
|
||||
|
||||
## Try Me
|
||||
|
||||
To quickly test if a configuration works, you can run `python -m vall_e.models.ar_nar --yaml="./data/config.yaml"`; a small trainer will overfit a provided utterance.
|
||||
|
@ -79,4 +77,18 @@ This script handles the VALL-E specific training code.
|
|||
For the most part, this handles:
|
||||
* feeding the model a batch from the dataloader
|
||||
* performing evaluation / validation when requested
|
||||
* unloading the `emb.qnt` model when its not needed anymore
|
||||
* unloading the `emb.qnt` model when its not needed anymore
|
||||
|
||||
For single GPUs, simply running `python3 -m vall_e.train --yaml="./training/config.yaml`.
|
||||
|
||||
For multiple GPUs, or exotic distributed training:
|
||||
* with `deepspeed` backends, simply running `deepspeed --module vall_e.train --yaml="./training/config.yaml"` should handle the gory details.
|
||||
* with `local` backends, simply run `torchrun --nnodes=1 --nproc-per-node={NUMOFGPUS} -m vall_e.train --yaml="./training/config.yaml"`
|
||||
|
||||
You can enter `save` to save the state at any time, or `quit` to save and quit training.
|
||||
|
||||
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
|
||||
|
||||
Some additional flags can be passed as well:
|
||||
* `--eval`: only run the evaluation / validation pass, then exit afterwards.
|
||||
* `--eval-random-text-prompts`: use random text prompts for the evaluation pass, rather than the provided text prompts in the dataset.
|
|
@ -261,8 +261,8 @@ class ModelExperimentalSettings:
|
|||
masking_train_p: float = 0.0 # odds of training with masking
|
||||
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
|
||||
|
||||
masking_ratio_fixed: bool = False
|
||||
ignore_inputs_for_loss: bool = False
|
||||
masking_ratio_fixed: bool = True # this sets the masking ratio to a fixed 80%
|
||||
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
|
||||
|
||||
# classifier-free guidance shit
|
||||
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
|
||||
|
|
|
@ -218,27 +218,31 @@ class AR_NAR(Base):
|
|||
if cfg.lora is not None:
|
||||
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
|
||||
|
||||
"""
|
||||
# to-do: check if gumbel sampling works / helps
|
||||
def log(x, eps = 1e-20):
|
||||
return torch.log(x.clamp(min = eps))
|
||||
|
||||
def gumbel_sample(x, temperature = 1., dim = -1):
|
||||
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
|
||||
"""
|
||||
|
||||
# convert (N)AR specific args
|
||||
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
|
||||
|
||||
min_length = sampling_kwargs.pop("min_duration", 1)
|
||||
max_length = sampling_kwargs.pop("max_duration", 500)
|
||||
max_steps = sampling_kwargs.get("max_steps", 25)
|
||||
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
|
||||
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
|
||||
|
||||
temperature = sampling_kwargs.pop("temperature", 1.0)
|
||||
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
|
||||
cfg_strength = sampling_kwargs.get("cfg_strength", 3.0) # this really helps keep audio coherent so far
|
||||
start_noise = sampling_kwargs.get("denoise_start", 0.0)
|
||||
end_noise = sampling_kwargs.get("denoise_end", 1.0)
|
||||
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
||||
|
||||
len_list = [ clamp(l, 1, max_length) for l in len_list ]
|
||||
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
||||
|
||||
# if we're denoising from an existing sequence
|
||||
if start_noise > 0.0 and resps_list is not None:
|
||||
|
@ -255,6 +259,7 @@ class AR_NAR(Base):
|
|||
prev_list = resps_list
|
||||
|
||||
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
|
||||
# ramp down over time
|
||||
annealing = 1.0 - timestep
|
||||
# get noise level, per cosine scheduling
|
||||
noise_p = math.cos( timestep * math.pi * 0.5 )
|
||||
|
|
|
@ -417,7 +417,7 @@ class Base(nn.Module):
|
|||
self.stop_token = self.n_audio_tokens # id 1024
|
||||
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
||||
self.version = self.config.version if self.config is not None else 5
|
||||
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if "ar" in self.capabilities else 0)
|
||||
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
||||
|
||||
self.arch_type = self.config.arch_type if self.config is not None else "llama"
|
||||
|
||||
|
@ -1221,6 +1221,9 @@ class Base(nn.Module):
|
|||
)
|
||||
"""
|
||||
|
||||
if classifier_level == "AR:0:0":
|
||||
classifier_level = "NAR:0:0"
|
||||
|
||||
embedding = self.resps_emb(
|
||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
#offset = 0 if classifier_level.startswith("AR:") else 1,
|
||||
|
|
|
@ -154,28 +154,27 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
if engine.hyper_config.experimental.hf:
|
||||
resps_list = engine( **base_kwargs )
|
||||
elif "len" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
max_steps = kwargs.pop("max_steps", 500)
|
||||
|
||||
if True:
|
||||
if "denoise_start" in kwargs:
|
||||
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
||||
kwargs["denoise_start"] = 0.5
|
||||
else:
|
||||
len_list = engine( max_steps=5, **kwargs )
|
||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||
|
||||
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
|
||||
kwargs = base_kwargs | cfg.evaluation.wargs
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
|
Loading…
Reference in New Issue
Block a user