From aee08b73074290d8495da4a3cb9158a932f3cb4c Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 3 Nov 2024 09:58:29 -0600 Subject: [PATCH] changed layerskip float16 training warning (since it didnt seem to fry on my 4xV100 system) --- README.md | 12 ++++++++++-- vall_e/demo.py | 6 ++++++ vall_e/inference.py | 2 ++ vall_e/models/ar_nar.py | 7 +++++++ vall_e/models/base.py | 17 +++++++++-------- vall_e/train.py | 2 +- 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 281273b..2fae71d 100755 --- a/README.md +++ b/README.md @@ -240,6 +240,8 @@ And some experimental sampling flags you can use too (your mileage will ***defin * `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor. * `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor. * `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within. +* `--layer-skip` (AR only) enables early-exit layer skipping if the model is confident enough (for compatible models) +* `--layer-skip-exit-layer`: (AR only) maximum layer to use (for compatbiel models) ### Speech-to-Text @@ -313,12 +315,19 @@ So far, this only allows you to load a different model without needing to restar * [ ] audio streaming - this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio. - something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming. +* [ ] speed up inferencing + - KV caching both yields broken output and quadratically slow output, unless I'm doing something grossly wrong. + - A pure HF model is the only way to fix this, but converting the model to one is a bit of a chore. + - Speculative sampling seems overkill for small models (and in reality seems like it's better to just train a larger model). + - Self-speculation through layer-skipping doesn't offer any tangible speedups, sadly. * [ ] replace the phonemizer with something that doesn't depend on espeak * [ ] train the model to handle text => phoneme (without a hit to the rest of the model) * [ ] ...and phonemes => text * [ ] allow raw text as input instead - espeak is nice, but I can only really put my whole trust with phonemizing English. - a small model trained to handle converting text to phonemes might work, but has it's own problems (another model to carry around, as accurate as the dataset it was trained against, requires training for each language... etc). +* [ ] smarter/clever inferencing, such as: + * [ ] "rolling" context, where the last generated sentence is the prefix for the next sentence. * [ ] explore exotic features like: * using a pure text vocab rather than IPA phonemes (as a transformer should be "smart" enough to map text tokens) * interleaving by using summed embedding tokens: @@ -339,8 +348,7 @@ Despite how lightweight it is in comparison to other TTS's I've meddled with, th + `model.experimental.p_rvq_levels: [0,0,0,0,0,0,0,1,2,3,4,5,6,7]` seems to help? * speakers that aren't similar to an audiobook narrator voice has similarity issues due to the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling. + although LoRAs help a ton for fixing results for a single voice. - + a diverse dataset in prosidy and speaker (such as a corpus sourced from dramatic media like video games) helps a ton. -* On my test system (7900XTX), it seems inferencing quality depends on the moon phase; I don't know if it's a matter of ROCm nuances (since I've always found it to not be up to par with actual CUDA) or `bfloat16` (due to the model being trained under `float16`+AMP) being the culprit, but your mileage *will* vary depending on the system + dtype + sampler settings. + + a diverse dataset in prosidy and speaker (such as a corpus sourced from dramatic media like video games) helps a ton, but still has issues for speakers not similar to any seen speakers. ## Notices and Citations diff --git a/vall_e/demo.py b/vall_e/demo.py index f8e4d7f..61b053a 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -144,6 +144,12 @@ def main(): comparison_kwargs["enabled"]["ar_temp"] = 0.666 comparison_kwargs["enabled"]["top_k"] = 27 comparison_kwargs["enabled"]["top_p"] = 0.9 + elif args.comparison == "layerskip": + comparison_kwargs["suffix"] = "layerskip" + comparison_kwargs["titles"] = [f"Without LayerSkip", "With LayerSkip"] + + comparison_kwargs["disabled"]["layer_skip"] = False + comparison_kwargs["enabled"]["layer_skip"] = True elif args.comparison == "ar-temp": current_temp = args.ar_temp other_temp = 1.0 diff --git a/vall_e/inference.py b/vall_e/inference.py index 9365826..5363942 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -338,6 +338,8 @@ class TTS(): sampling_min_temperature=min_nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, + #sampling_layer_skip=layer_skip, + #sampling_layer_skip_exit_layer=layer_skip_exit_layer, disable_tqdm=not tqdm, use_lora=use_lora, diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index ffaf0be..575320f 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -205,6 +205,11 @@ class AR_NAR(Base): prev_list = resps_list + sampling_layer_skip_variables = {} if sampling_layer_skip else None + + if sampling_layer_skip: + sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers + for n in trange( max_levels, desc="NAR", disable=disable_tqdm ): level = prev_list[0].shape[-1] if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels @@ -227,6 +232,8 @@ class AR_NAR(Base): output = super().forward( inputs=inputs, quant_levels=quant_levels, + + layer_skip_variables=sampling_layer_skip_variables, ) logits, state = output.logits, output.state diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 5b1ef1e..d8de287 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1448,7 +1448,7 @@ class Base(nn.Module): kwargs = { "logits_entropy": 0.1, "logits_varentropy": 0.1, - "min_layer": self.n_layers // 2, + "min_layer": self.n_layers // 4, "max_layer": self.n_layers, } @@ -1466,9 +1466,9 @@ class Base(nn.Module): # output projection layer with masking if self.classifier is not None: - x = self.classifier(x) * m + x = self.classifier(x) # * m elif self.classifiers is not None: - logits = self.classifiers(logits, levels = classifier_quant_levels) * m + logits = self.classifiers(logits, levels = classifier_quant_levels) # * m # calculate metrics metrics = calculate_entropix_metrics( logits ) @@ -1528,19 +1528,19 @@ class Base(nn.Module): # output projection layer with masking if self.classifier is not None: - logits = self.classifier(logits) * m + logits = self.classifier(logits) # * m if output.hidden_states: for i, state in enumerate( hidden_states ): - hidden_states[i] = self.classifier(hidden_states[i]) * m + hidden_states[i] = self.classifier(hidden_states[i]) # * m # to-do: piece-wise classification, now that there's a head for text # although again, one single monolithic head would be preferable instead...... elif self.classifiers is not None: - logits = self.classifiers(logits, levels = classifier_quant_levels) * m + logits = self.classifiers(logits, levels = classifier_quant_levels) # * m if hidden_states is not None: for i, state in enumerate( hidden_states ): - hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m + hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) # * m # Remove padding logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ] @@ -1618,13 +1618,14 @@ class Base(nn.Module): scores = None entropy = None + #logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] + #logits = [ logit.to(device="cpu") for logit in logits ] # (AR) entropix sampling # we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token) if attentions is not None and quant_levels is None: # move to CPU for speedups seq_lens = [ logit.shape[0] for logit in logits ] - logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] attentions = torch.stack(attentions, dim=1).to(device="cpu") # ( batch, layer, heads, seq_len, seq_len ) res = [ sample_entropix( diff --git a/vall_e/train.py b/vall_e/train.py index 8f449bc..1573a09 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -230,7 +230,7 @@ def train(): """ # pre-training config validation if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16": - _logger.warning(f"Training with LayerSkip enabled with float16 will result in frying the model. Please use bfloat16.") + _logger.warning(f"Training with LayerSkip enabled with float16 may result in frying the model if the loss scale gets too small (<=8K) or with too large of a de facto batch size (>512 samples).") # train trainer.train(