Fixed STT in the web UI
parent 8927bad7bc
commit 2495a7ef67
@@ -72,7 +72,11 @@ This class governs the behavior of the evaluation / validation pass during training.

If `cfg.evaluation.size > 0`, then the evaluation / validation passes are triggered every `cfg.evaluation.frequency` iteration steps.

-During evaluation, a separate copy of the training dataset will be sampled and the inputs will be inferenced to generate an output, while during validation, the validation dataset is sampled from instead.
+During evaluation:
+* for the `subtrain` evaluation pass, the training dataset is directly sampled through indices, rather than through the iterator, to avoid having to duplicate the dataset.
+  * in the future, the samples during this pass should sample around the training dataloader's current position.
+* for the `val` validation pass, the validation dataset is sampled through the dataloader's iterator.
+  * currently, the validation dataloader's sampler is not stored.

A total of `cfg.evaluation.size` samples are inferenced in no more than `cfg.evaluation.batch_size`-sized batches (no more than, because batched samplers may return different-sized batches).
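As a rough sketch of how those knobs interact (illustrative names only, not the project's actual trainer loop):

```py
# `cfg.evaluation.size`, `cfg.evaluation.frequency`, and `cfg.evaluation.batch_size`
# mirror the fields documented above; the function name is made up for illustration.
def should_run_eval_pass(cfg, iteration: int) -> bool:
    # skipped entirely when size is 0; otherwise fires on a fixed cadence of steps,
    # then consumes `cfg.evaluation.size` samples in batches of at most `cfg.evaluation.batch_size`
    return cfg.evaluation.size > 0 and iteration % cfg.evaluation.frequency == 0
```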
docs/data.md — 17 lines changed
@@ -7,12 +7,17 @@ Most of these settings live under `cfg.dataset`.

## Dataset

The provided reference model was trained on `?`k hours of audio with a mix of:
-* LibriTTS-R's entire dataset
-* `small`+`medium`+`duplicate` portions of LibriVox
-* Emilia's German, French, and Japanese dataset
-* 12K hours of a privately sourced corpus of 425 audiobooks
-* a small portion of Emilia's English dataset
-* a personal small corpus of transcribed utterances from a selection of video games
+* 490.151 hours (out of 585 hours) of LibriTTS-R's entire dataset
+* 8362.304 hours (out of 10270 hours) of `small`+`medium`+`duplicate` portions of LibriLight
+* 4467.611 hours (out of `?` hours) of Emilia's German, French, and Japanese dataset
+* 2927.186 hours (out of `?` hours) of a privately sourced corpus of 425 audiobooks
+* 2364.799 hours (out of `?` hours) of Emilia's English dataset
+* 54.775 hours of a personal small corpus of transcribed utterances from a selection of video games
+
+These durations were reported directly from the training script.
+* Utterances under 3 seconds or above 32 seconds were culled from the duration count.
+* Metadata was *mostly* derived from the transcription metadata.
+  * LibriTTS-R's duration metadata was derived from the quantized audio size.

### Leverage Your Own Dataset
@@ -4,6 +4,6 @@ To export the models, run: `python -m vall_e.export --yaml=./training/config.yaml`

This will export the latest checkpoints, for example, under `./training/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.

-Desite being called `fp32.pth`, you can export it to a different precision type with `--dtype=float16|bfloat16|float32`.
+Despite being called `fp32.sft` or `fp32.pth`, you can export it to a different precision type with `--dtype=float16|bfloat16|float32`.

You can also export to `safetensors` with `--format=sft`, and `fp32.sft` will be exported instead.
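For example, exporting a bfloat16 safetensors checkpoint should just be a matter of combining the two flags documented above: `python -m vall_e.export --yaml=./training/config.yaml --dtype=bfloat16 --format=sft`.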
@@ -13,44 +13,51 @@ To synthesize speech: `python -m vall_e <text> <ref_path> <out_path> --yaml=<yaml_path>`

Some additional flags you can pass are:
* `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
* `--task`: task to perform. Defaults to `tts`, but accepts `stt` for transcriptions.
-* `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
+* `--max-duration`: maximum token-duration for inferencing through the AR aspect of the model. Every second corresponds to 75 steps.
+* `--max-steps`: maximum steps for inferencing through the NAR-len aspect of the model.
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
-* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it works fine.
-* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, the lower value, the better. Set to `0` to enable greedy sampling.
-* `--input-prompt-length`: the maximum duration the input prompt can be (~6 seconds is fine, longer durations lead to slower generations for "better" accuracy, as long as the model was trained against such input prompt durations)
+* `--ar-temperature`: sampling temperature to use for the AR/NAR pass. 0 enables greedy sampling.
+  * For the AR, ~1.0 is *fine*, but lowering the temperature adheres better to the prosody of the input prompt.
+  * For the AR, low temperatures require a repetition penalty to prevent outputs from degenerating.
+  * For the NAR, greedy sampling is best, but can be raised to 0.2.
+* `--input-prompt-length`: the duration of the input prompt (~6 seconds is fine, longer durations lead to slower generations for "better" accuracy). 0 does not repeat/trim.
+  * If a prompt is shorter than the given duration, it's repeated to the duration size.

And some experimental sampling flags you can use too (your mileage will ***definitely*** vary, but most of these are bandaids for a bad AR):
* `--input-prompt-prefix`: (AR only) treats the input prompt as the initial response prefix, but...
  * the transcription of the prompt needs to be in the input text prompt.
  * doesn't perform all that well (I believe the model needs to be trained a bit on this, as `tts-c`).
-* `--min-ar-temp`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temp)`.
+* `--min-temperature`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temperature)` (see the sketch after this list).
  + This simply uplifts the [original implementation](https://github.com/kalomaze/koboldcpp/blob/dynamic-temp/llama.cpp#L5132) to perform it.
  + **!**NOTE**!**: This does not seem to resolve any issues with setting too high/low of a temperature. The right values are yet to be found.
* `--top-p`: limits the sampling pool to the top sum of values that equal `P`% probability in the probability distribution.
* `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
* `--min-p`: only logits above `P`% probability are considered for sampling (or something, I'm still unsure how this differs from top-p).
* `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
* `--repetition-penalty-decay`: modifies the above factor applied to scale based on how far away it is in the past sequence.
* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky due to the AR already being well correlated with the length.
* `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling.
  + This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
* `--mirostat-tau`: (AR only) the "surprise value" when performing mirostat sampling.
  + This simply uplifts the [original implementation](https://github.com/basusourya/mirostat/blob/master/mirostat.py) to perform it.
  + **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime at least).
* `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.
* `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor.
* `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor.
* `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within.
* `--layer-skip`: enables early-exit layer skipping if the model is confident enough (for compatible models)
* `--layer-skip-exit-layer`: maximum layer to use
* `--layer-skip-entropy-threshold`: the maximum the logits' entropy (confidence) needs to be before exiting early
* `--layer-skip-varentropy-threshold`: the maximum the logits' varentropy (confidence spread) needs to be before exiting early
* `--refine-on-stop`: (AR only) uses the last step's logits for the entire final output sequence, rather than the step-by-step iterative sequence.
  + This needs experimenting with to see if there's any downside.
  + to-do: compare the probability scores with the original output sequence, and pick the best one.
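As referenced in the `--min-temperature` bullet above, a rough sketch of the idea behind the dynamic temperature pathway; this is a simplified illustration of the linked approach (scaling between the minimum and the configured temperature by how confident the model is), not the exact code this project or koboldcpp uses:

```py
import torch

def dynamic_temperature(logits: torch.Tensor, max_temperature: float, min_temperature: float) -> float:
    # "confidence" here is just the probability of the best token; a more faithful port
    # would use normalized entropy, but the idea is the same: the more confident the
    # model is, the closer the effective temperature gets to the minimum.
    probs = torch.softmax(logits, dim=-1)
    confidence = probs.max().item()
    return min_temperature + (max_temperature - min_temperature) * (1.0 - confidence)
```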
Some arguments are able to be prefixed with `ar-` and `nar-` to only use that setting for its respective pass. At the moment through the CLI, this includes:
* `temperature`

### Speech-to-Text

-The `ar+nar-tts+stt-llama-8` model has received additional training for a speech-to-text task against EnCodec-encoded audio.
+The `ar+nar-tts+stt-llama-8` model (now the reference model) has received additional training for a speech-to-text task against EnCodec-encoded audio.

Currently, the model only transcribes back into the IPA phonemes it was trained against, so an additional model or external program is required to translate the IPA phonemes back into text.
* this does make a model that can phonemize and unphonemize text more desirable in the future, to replace espeak (handling this as an additional task would require additional embeddings and output heads, and could harm the model, as raw text is not a modality the model is trained on).
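As far as actually invoking it goes, passing `--task=stt` to the same `python -m vall_e <text> <ref_path> <out_path> --yaml=<yaml_path>` command shown above should be what flips inferencing into transcription mode, with the output being the IPA phoneme string rather than audio.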
@@ -56,13 +56,18 @@ However, having a pure NAR is challenging, as you need to both explicitly provide […]

The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation

The reference model provided has *some* NAR demasking (mock diffusion) aware training to facilitate a pure NAR model, but:
* Sampling absolutely requires rep pen, or the output degenerates.
* Output isn't so great, as there's artifacting from either an underbaked model or a naive sampler.

To-do: fill this out more when it works. Getting this to work is a huge pain.
* Some masked transformers do not "inject" any timestep information (Text-To-Image Muse, as far as I can tell)
* Others "expose" it by applying a timestep embedding after pre/post attention normalization
  * Except F5-TTS only does this pre for its DiTS, but not UnetT
  * MaskGCT does it both pre and post
  * the test trainer actually degrades the output immensely when doing this
-* I'm sure I've seen a masked transformer not have CFG, but most of them seem to do (and all seem to be poorly documentated on specifically how its doing it for my dumb brain)
+* I'm sure I've seen a masked transformer not have CFG, but most of them seem to.
  * This helps the base AR+NAR tasks and provides CFG sampling for such tasks anyways (see the sketch below).
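For reference, the CFG being talked about here is the standard classifier-free-guidance blend between a conditioned pass and a null-conditioned pass; the `AR_NAR` code later in this commit does exactly this, and a stripped-down sketch of it looks like:

```py
# cond_logits: logits from the normal (conditioned) forward pass
# null_logits: logits from a pass with the conditioning nulled out
# values above 1.0 push past the conditioned logits; the code in this commit applies
# the blend whenever cfg_strength > 0
def apply_cfg(cond_logits, null_logits, cfg_strength: float):
    return null_logits + (cond_logits - null_logits) * cfg_strength
```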
## Embeddings

@@ -70,6 +75,8 @@ The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings.

While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.

With attention-based transformers, most embeddings can serve as tokens themselves, with the attention mechanism attending to them. Theoretically, there should be little to no functional difference between "tokenizing" an embedding and summing a modifying embedding, but experimentation is needed for this assertion.

### Text Embeddings

The input text phonemes (or output for STT) are passed through an embedding head (`text`), similar to how a normal text LLM would. Nothing fancy is required, as it's very straightforward.
@@ -85,10 +92,10 @@ This embedding provides the requested language for the model to be aware of.

This *mostly* isn't necessary, but VALL-E X's paper mentions needing a token for the language itself, and other solutions like XTTS2 provide a language token as well.

-In practice, this seems to help govern the accent general mannerisms associated with that language. For example, prompting French or German with the language set to `en` will give typical foreigner speech of trying to speak a language they don't know.
-* Consequently, since this does tie to accents more, ***extreme*** attention is to be paid to the dialects being trained against, instead of naively grouping, say, all of Spansih to one language code.
+In practice, this seems to help govern the accent / general mannerisms associated with that language. For example, prompting French or German with the language set to `en` will give typical foreigner speech of trying to speak a language they don't know.
+* Consequently, since this does tie to accents more, ***extreme*** attention is to be paid to the dialects being trained against, instead of naively grouping, say, all of Spanish to one language code.

-This embedding probably helps the model with being able to perform cross-lingual outputs, but I did not do any experimentations on a model without this, as the reference `ar+nar-llama-8` was trained with this from the beginning (and maybe the `ar+nar-retnet-8` experiment).
+This embedding probably helps the model with being able to perform cross-lingual outputs, but I did not do any experimentation on a model without this, as the reference `ar+nar-llama-8` was trained with this from the beginning, with the small amount of Japanese in my dataset anyhow (and maybe the `ar+nar-retnet-8` experiment).

#### Tone Embedding
@@ -113,6 +120,7 @@ However, the `resp` requires some extra care, as the model needs to both causally […]

* In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
* I believe this is because the model needs to "know" whether to predict ~~the next token in the sequence, or the token in the same position of the next RVQ level~~ which tokens of a given embedding.
  * In other words, the AR's RVQ level 0 embedding predicts itself, while the NAR's embeddings predict the next level's embeddings.
  * This is evident in how RVQ level 0 can be trained causally and in parallel with its own embeddings, rather than having limiting issues when reusing the embedding across the two.
* Unfortunately, providing a token for the current/target RVQ level within the input sequence doesn't seem to help? I don't remember if I experimented with this or not, but testing of a "sane" `resp` embedding proved to be unfruitful.

The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
@@ -127,7 +135,7 @@ Finally, the model *may* then sum each embedding level back down to one sequence.
* It *could* be beneficial to train a model under mixed modes, but requires experimentation.
* The reference model was trained originally without summing, then trained with summing.

-Additionally, it's *technically* possible to instead use the embeddings from the core model used to encode the audio, but in theory this may exclude specific features the model has encoded within the embeddings.
+Additionally, it's *technically* possible to instead use the embeddings from the model used to encode the audio (for example, EnCodec's embeddings), but in theory this may exclude specific features the model has encoded within the embeddings.
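To make the summing concrete, a minimal sketch of the idea (names and sizes here are made up, not the project's actual modules): each RVQ level gets its own embedding table, and the per-level embeddings for a frame are summed into a single vector so the sequence stays at one entry per frame.

```py
import torch
import torch.nn as nn

n_levels, codebook_size, dim = 8, 1024, 1024
level_embs = nn.ModuleList([nn.Embedding(codebook_size, dim) for _ in range(n_levels)])

def embed_codes(codes: torch.Tensor, quant_level: int) -> torch.Tensor:
    # codes: [seq_len, n_levels] of RVQ codebook indices;
    # sum the embeddings of levels 0..quant_level down to one [seq_len, dim] sequence
    return sum(level_embs[l](codes[:, l]) for l in range(quant_level + 1))
```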
#### RVQ Level Embedding
@@ -204,7 +212,7 @@ This task will follow a reverse sequence of `<audio><language><RVQ level><output>`.

The model can be prompted in creative ways to yield some interesting behaviors:
* prompting without an input audio prompt will have the model generate a random voice at the "cost" of some unintelligible utterance at the beginning of the output response (despite doing no promptless training).
  * finetunes / LoRAs can benefit from this by allowing promptless synthesis, while still having the option of an input audio prompt for guidance.
  * classifier-free-guidance-aware training does fix this, but this property emerges without it.
* prompting with an input text prompt being the transcription of the input audio prompt will have the response follow the input prompt very closely (despite not doing input=output training).
  * this should allow for easy transcription editing without much fuss.
@@ -8,7 +8,7 @@ Most of these sampler functions do what's written on the tin, but for clarity:

## Samplers

-When sampling, the output logits are picked for sampling according to the current inference mode. For the AR, only the last token (or last `causal_size` tokens) are used for sampling, while the NAR relies on the previous RVQ level's sequence to determine how many tokens to sample in parallel.
+When sampling, the output logits are picked for sampling according to the current inference mode. For the AR, only the last token (or last `causal_size` tokens) are used for sampling, while the NAR relies on the previous sequence to determine how many tokens to sample in parallel.

As the model is trained more, low temperatures are preferred over high temperatures for the AR, while greedy sampling is almost always preferred for the NAR.
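A small sketch of that logit-picking step, assuming `logits` is the `[seq_len, vocab]` output for one sequence (illustrative only, not the project's exact code):

```py
def pick_sampled_positions(logits, causal: bool, causal_size: int = 1, prev_len: int = 0):
    if causal:
        # AR: only the last token (or last `causal_size` tokens) gets sampled
        return logits[-causal_size:]
    # NAR: sample every position of the previous sequence in parallel
    return logits[-prev_len:]
```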
@@ -18,7 +18,8 @@ Greedy sampling is enabled when the sampling temperature is <= 0, where the most likely token is picked.

This function (`reptition_penalize`) applies a penalty to target logits to avoid repetitive output.

-This is implemented by iterating through a list of past tokens, and penalizing that token's probability score by the requested amount.
+This is implemented by penalizing future positions against repeating the currently iterated token.
+* This distinction is required to penalize for the NAR, while the AR only penalizes the single token being inferenced.

An optional value can also be passed to factor in how far away that token is.
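As a rough illustration of the decayed penalty described above (a simplified AR-style sketch; the project's actual `reptition_penalize` additionally handles the NAR's forward-window case shown in the code change further down):

```py
import torch

def penalize_repeats(logits: torch.Tensor, previous: list[int], factor: float = 1.25, decay: float = 0.0) -> torch.Tensor:
    # logits: [seq_len, vocab]; `previous` holds already-generated token ids
    for distance, token in enumerate(reversed(previous), start=1):
        # mirror the penalty form used in the code below: factor * (distance ** decay)
        logits[:, token] /= factor * (distance ** decay)
    return logits
```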
@@ -17,6 +17,7 @@ Synthesizing speech is simple:
* `Inference`: Button to start generating the audio.
* `Basic Settings`: Basic sampler settings for most uses.
* `Sampler Settings`: Advanced sampler settings that are common for most text LLMs, but need experimentation.
* `Experimental Settings`: Settings used for testing. `cfg.experimental=True` enables this tab.

All the additional knobs have a description that can be correlated to the inferencing CLI flags.
@@ -30,4 +31,4 @@ In the future, this *should* contain the necessary niceties to process raw audio […]

## Settings

-So far, this only allows you to load a different model without needing to restart. The previous model should seamlessly unload, and the new one will load in place.
+So far, this only allows you to load a different model under a different dtype, device, and/or attention mechanism, without needing to restart. The previous model should seamlessly unload, and the new one will load in place.
@@ -587,6 +587,8 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
        phones = entry['phones'] if "phones" in entry else 0
        duration = entry['duration'] if "duration" in entry else 0

        #print( id, duration )

        # add to duration bucket
        k = key(id, entry)
        if type not in _durations_map:
@@ -1579,7 +1581,7 @@ def create_dataset_metadata( skip_existing=False ):
        utterance_metadata = process_artifact_metadata( artifact )
        # to-do: derive duration from codes if duration is malformed because this happened to me with LibriTTS-R
-       #utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second
+       utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second

        for k, v in utterance_metadata.items():
            metadata[id][k] = v
@@ -222,6 +222,8 @@ class AR_NAR(Base):
        max_length = sampling_kwargs.pop("max_duration", 500)
        max_steps = sampling_kwargs.get("max_steps", 25)
        refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
        entropix_sampling = sampling_kwargs.get("entropix_sampling", False)

        temperature = sampling_kwargs.pop("temperature", 1.0)
        cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
@@ -299,7 +301,7 @@ class AR_NAR(Base):
            prev_list=prev_list,
            quant_levels=quant_levels,

-           temperature=temperature * (steps_until_x0 / max_steps) ,
+           temperature=temperature * (steps_until_x0 / max_steps),
            **sampling_kwargs,
        )
@@ -317,13 +319,67 @@ class AR_NAR(Base):
        # sample with gumbelnoise
        # This actually lobotomizes things
        #sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
        #sampled_ids = [ gumbel_sample( logits, temperature=temperature * (steps_until_x0 / max_steps), dim=-1 ) for logits in filtered_sampled.logits[0] ]
        sampled_ids = filtered_sampled[0]

        # keep unmasked tokens
        resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
        # update scores (conjugated to put the worst scores at the top)
-       scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
+       scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]

        # refinement step
        if refine_on_stop:
            inputs = super().inputs(
                text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                lang_list=lang_list,
                tone_list=tone_list,
                quant_levels=quant_levels,
            )
            output = super().forward(
                inputs=inputs,
                quant_levels=quant_levels,
                #layer_skip_variables=sampling_layer_skip_variables,
            )

            logits = output.logits

            if cfg_strength > 0:
                null_inputs = super().inputs(
                    text_list=null_text,
                    proms_list=null_prom,
                    resps_list=resps_list,
                    lang_list=lang_list,
                    tone_list=tone_list,
                    quant_levels=quant_levels,
                )
                null_output = super().forward(
                    inputs=null_inputs,
                    quant_levels=quant_levels,
                    #layer_skip_variables=sampling_layer_skip_variables,
                )
                for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
                    logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength

            sampled = super().sample(
                logits=logits,
                prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
                **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
            )

            # remove stop token
            resps_list = [self._prune(r, self.stop_token) for i, r in enumerate(resps_list)]

            # get how much we need to slice from the end
            slice_lengths = [ sequence.shape[-1] for sequence in resps_list ]
            # -1 for the stop token
            logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
            # greedy sample from the sequence
            refined_list = [ logit.argmax(dim=-1) for logit in logits ]
            # to-do: compare scores
            # set the "refined" list as the output
            resps_list = refined_list

        if cfg.experimental and max_steps > 0:
            print( timestep, steps_until_x0, noise_p, resps_list, scores )
@@ -13,7 +13,7 @@ from .utils import clamp

# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
-def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=0 ):
+def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ):
    if factor == 1.0 or previous is None:
        return logits
@@ -38,6 +38,9 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ):
        if is_nar and i < logits.shape[0]:
            start = i + 1

            if limit:
                end = i + limit

        logits[start:end, token] /= factor * (distance ** decay)

        # add to set if we care about it
@@ -306,19 +306,12 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):

    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    # I'm very sure I can procedurally generate this list
    parser.add_argument("--text", type=str, default=kwargs["text"])
    parser.add_argument("--task", type=str, default="tts")
    parser.add_argument("--references", type=str, default=kwargs["reference"])
    parser.add_argument("--max-duration", type=int, default=0)
    parser.add_argument("--language", type=str, default=kwargs["language"])
    parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
    parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
    parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
    parser.add_argument("--max-levels", type=int, default=kwargs["max-levels"])
    parser.add_argument("--ar-temperature", type=float, default=kwargs["ar-temperature"])
    parser.add_argument("--nar-temperature", type=float, default=kwargs["nar-temperature"])
    parser.add_argument("--min-ar-temperature", type=float, default=kwargs["min-ar-temperature"])
    parser.add_argument("--min-nar-temperature", type=float, default=kwargs["min-nar-temperature"])
    parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
    parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
    parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
    parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
@@ -331,16 +324,8 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
    parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
    parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
    parser.add_argument("--entropix-sampling", action="store_true")
    parser.add_argument("--layer-skip", action="store_true")
    parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
    parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"])
    parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"])
    parser.add_argument("--refine-on-stop", action="store_true")
    parser.add_argument("--cfg-strength", type=float, default=kwargs["cfg-strength"])
    args, unknown = parser.parse_known_args()

    """
    if not args.references:
        raise Exception("No reference audio provided.")
@@ -361,21 +346,14 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):

    sampling_kwargs = dict(
        max_duration=args.max_duration,
-       ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
-       min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
+       ar_temperature=args.ar_temperature,
+       min_ar_temperature=args.min_ar_temperature,
        top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
        repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
        length_penalty=args.length_penalty,
        beam_width=args.beam_width,
        mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
        dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
        entropix_sampling=args.entropix_sampling,
        layer_skip=args.layer_skip,
        layer_skip_exit_layer=args.layer_skip_exit_layer,
        layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
        layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
        refine_on_stop=args.refine_on_stop,
        denoise_start=args.denoise_start,
    )

    gr.Info("Inferencing...")