changed layerskip float16 training warning (since it didnt seem to fry on my 4xV100 system)
This commit is contained in:
parent
3826f9bae4
commit
aee08b7307
12
README.md
12
README.md
|
@ -240,6 +240,8 @@ And some experimental sampling flags you can use too (your mileage will ***defin
|
|||
* `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor.
|
||||
* `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor.
|
||||
* `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within.
|
||||
* `--layer-skip` (AR only) enables early-exit layer skipping if the model is confident enough (for compatible models)
|
||||
* `--layer-skip-exit-layer`: (AR only) maximum layer to use (for compatbiel models)
|
||||
|
||||
### Speech-to-Text
|
||||
|
||||
|
@ -313,12 +315,19 @@ So far, this only allows you to load a different model without needing to restar
|
|||
* [ ] audio streaming
|
||||
- this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
|
||||
- something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming.
|
||||
* [ ] speed up inferencing
|
||||
- KV caching both yields broken output and quadratically slow output, unless I'm doing something grossly wrong.
|
||||
- A pure HF model is the only way to fix this, but converting the model to one is a bit of a chore.
|
||||
- Speculative sampling seems overkill for small models (and in reality seems like it's better to just train a larger model).
|
||||
- Self-speculation through layer-skipping doesn't offer any tangible speedups, sadly.
|
||||
* [ ] replace the phonemizer with something that doesn't depend on espeak
|
||||
* [ ] train the model to handle text => phoneme (without a hit to the rest of the model)
|
||||
* [ ] ...and phonemes => text
|
||||
* [ ] allow raw text as input instead
|
||||
- espeak is nice, but I can only really put my whole trust with phonemizing English.
|
||||
- a small model trained to handle converting text to phonemes might work, but has it's own problems (another model to carry around, as accurate as the dataset it was trained against, requires training for each language... etc).
|
||||
* [ ] smarter/clever inferencing, such as:
|
||||
* [ ] "rolling" context, where the last generated sentence is the prefix for the next sentence.
|
||||
* [ ] explore exotic features like:
|
||||
* using a pure text vocab rather than IPA phonemes (as a transformer should be "smart" enough to map text tokens)
|
||||
* interleaving by using summed embedding tokens:
|
||||
|
@ -339,8 +348,7 @@ Despite how lightweight it is in comparison to other TTS's I've meddled with, th
|
|||
+ `model.experimental.p_rvq_levels: [0,0,0,0,0,0,0,1,2,3,4,5,6,7]` seems to help?
|
||||
* speakers that aren't similar to an audiobook narrator voice has similarity issues due to the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling.
|
||||
+ although LoRAs help a ton for fixing results for a single voice.
|
||||
+ a diverse dataset in prosidy and speaker (such as a corpus sourced from dramatic media like video games) helps a ton.
|
||||
* On my test system (7900XTX), it seems inferencing quality depends on the moon phase; I don't know if it's a matter of ROCm nuances (since I've always found it to not be up to par with actual CUDA) or `bfloat16` (due to the model being trained under `float16`+AMP) being the culprit, but your mileage *will* vary depending on the system + dtype + sampler settings.
|
||||
+ a diverse dataset in prosidy and speaker (such as a corpus sourced from dramatic media like video games) helps a ton, but still has issues for speakers not similar to any seen speakers.
|
||||
|
||||
## Notices and Citations
|
||||
|
||||
|
|
|
@ -144,6 +144,12 @@ def main():
|
|||
comparison_kwargs["enabled"]["ar_temp"] = 0.666
|
||||
comparison_kwargs["enabled"]["top_k"] = 27
|
||||
comparison_kwargs["enabled"]["top_p"] = 0.9
|
||||
elif args.comparison == "layerskip":
|
||||
comparison_kwargs["suffix"] = "layerskip"
|
||||
comparison_kwargs["titles"] = [f"Without LayerSkip", "With LayerSkip"]
|
||||
|
||||
comparison_kwargs["disabled"]["layer_skip"] = False
|
||||
comparison_kwargs["enabled"]["layer_skip"] = True
|
||||
elif args.comparison == "ar-temp":
|
||||
current_temp = args.ar_temp
|
||||
other_temp = 1.0
|
||||
|
|
|
@ -338,6 +338,8 @@ class TTS():
|
|||
sampling_min_temperature=min_nar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
#sampling_layer_skip=layer_skip,
|
||||
#sampling_layer_skip_exit_layer=layer_skip_exit_layer,
|
||||
|
||||
disable_tqdm=not tqdm,
|
||||
use_lora=use_lora,
|
||||
|
|
|
@ -205,6 +205,11 @@ class AR_NAR(Base):
|
|||
|
||||
prev_list = resps_list
|
||||
|
||||
sampling_layer_skip_variables = {} if sampling_layer_skip else None
|
||||
|
||||
if sampling_layer_skip:
|
||||
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
|
||||
|
||||
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
|
||||
level = prev_list[0].shape[-1]
|
||||
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
||||
|
@ -227,6 +232,8 @@ class AR_NAR(Base):
|
|||
output = super().forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
logits, state = output.logits, output.state
|
||||
|
||||
|
|
|
@ -1448,7 +1448,7 @@ class Base(nn.Module):
|
|||
kwargs = {
|
||||
"logits_entropy": 0.1,
|
||||
"logits_varentropy": 0.1,
|
||||
"min_layer": self.n_layers // 2,
|
||||
"min_layer": self.n_layers // 4,
|
||||
"max_layer": self.n_layers,
|
||||
}
|
||||
|
||||
|
@ -1466,9 +1466,9 @@ class Base(nn.Module):
|
|||
|
||||
# output projection layer with masking
|
||||
if self.classifier is not None:
|
||||
x = self.classifier(x) * m
|
||||
x = self.classifier(x) # * m
|
||||
elif self.classifiers is not None:
|
||||
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
|
||||
logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
|
||||
|
||||
# calculate metrics
|
||||
metrics = calculate_entropix_metrics( logits )
|
||||
|
@ -1528,19 +1528,19 @@ class Base(nn.Module):
|
|||
|
||||
# output projection layer with masking
|
||||
if self.classifier is not None:
|
||||
logits = self.classifier(logits) * m
|
||||
logits = self.classifier(logits) # * m
|
||||
|
||||
if output.hidden_states:
|
||||
for i, state in enumerate( hidden_states ):
|
||||
hidden_states[i] = self.classifier(hidden_states[i]) * m
|
||||
hidden_states[i] = self.classifier(hidden_states[i]) # * m
|
||||
# to-do: piece-wise classification, now that there's a head for text
|
||||
# although again, one single monolithic head would be preferable instead......
|
||||
elif self.classifiers is not None:
|
||||
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
|
||||
logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
|
||||
|
||||
if hidden_states is not None:
|
||||
for i, state in enumerate( hidden_states ):
|
||||
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m
|
||||
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) # * m
|
||||
|
||||
# Remove padding
|
||||
logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
|
||||
|
@ -1618,13 +1618,14 @@ class Base(nn.Module):
|
|||
|
||||
scores = None
|
||||
entropy = None
|
||||
#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||
#logits = [ logit.to(device="cpu") for logit in logits ]
|
||||
|
||||
# (AR) entropix sampling
|
||||
# we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token)
|
||||
if attentions is not None and quant_levels is None:
|
||||
# move to CPU for speedups
|
||||
seq_lens = [ logit.shape[0] for logit in logits ]
|
||||
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||
attentions = torch.stack(attentions, dim=1).to(device="cpu") # ( batch, layer, heads, seq_len, seq_len )
|
||||
|
||||
res = [ sample_entropix(
|
||||
|
|
|
@ -230,7 +230,7 @@ def train():
|
|||
"""
|
||||
# pre-training config validation
|
||||
if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
|
||||
_logger.warning(f"Training with LayerSkip enabled with float16 will result in frying the model. Please use bfloat16.")
|
||||
_logger.warning(f"Training with LayerSkip enabled with float16 may result in frying the model if the loss scale gets too small (<=8K) or with too large of a de facto batch size (>512 samples).")
|
||||
|
||||
# train
|
||||
trainer.train(
|
||||
|
|
Loading…
Reference in New Issue
Block a user