added documentation on how these new sampling parameters are very iffy and you really need to know what you are doing to use them because this is audio generation and not text generation

This commit is contained in:
mrq 2023-09-08 20:43:36 -05:00
parent 14c78bae39
commit b922f35b6b
2 changed files with 9 additions and 5 deletions


@@ -124,9 +124,15 @@ To synthesize speech, invoke either (if exported the models): `python -m vall_e
Some additional flags you can pass are:
* `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output.
* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the most clean output.
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it also work fine.
* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides clean output, but values upward of `0.6` seem fine too.
And some experimental sampling flags you can use too (your mileage will ***definitely*** vary); see the sketch after this list for what they typically do to the logits:
* `--top-p`: limits the sampling pool to the smallest set of top tokens whose probabilities sum to at least `P` (nucleus sampling).
* `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
* `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky.
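
For reference, below is a minimal sketch of the generic top-k / top-p (nucleus) filtering these flags refer to, applied to the logits before sampling. The helper name and exact math are illustrative assumptions, not necessarily how this codebase implements them:

```python
# illustrative only: the flag names above are real, but this helper is a generic
# sketch of top-k / top-p filtering, not this repo's actual sampling code
import torch
import torch.nn.functional as F

def top_k_top_p_filter( logits, top_k=0, top_p=1.0 ):
    # --top-k: keep only the K highest-scoring tokens
    if top_k > 0:
        kth_best = torch.topk( logits, top_k ).values[..., -1, None]
        logits = logits.masked_fill( logits < kth_best, float("-inf") )
    # --top-p: keep the smallest set of tokens whose probabilities sum to at least P
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort( logits, descending=True )
        cumulative = torch.cumsum( F.softmax( sorted_logits, dim=-1 ), dim=-1 )
        to_remove = cumulative > top_p
        to_remove[..., 1:] = to_remove[..., :-1].clone()  # shift right so the token crossing P is kept
        to_remove[..., 0] = False
        logits = logits.masked_fill( to_remove.scatter( -1, sorted_indices, to_remove ), float("-inf") )
    return logits

# sampling then draws from the filtered, temperature-scaled distribution, e.g.:
# probs = F.softmax( top_k_top_p_filter( logits, top_k=64, top_p=0.95 ) / temperature, dim=-1 )
# token = torch.multinomial( probs, 1 )
```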
## To-Do


@@ -411,8 +411,6 @@ class Base(nn.Module):
)
return logits
# (NAR) return the entire generated response
if quant_levels is not None:
@@ -427,7 +425,7 @@ class Base(nn.Module):
# perform repetition penalizing
logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
# perform length penalizing
# (AR) perform length penalizing
if quant_levels is None and self.causal:
logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
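
# Note: the reptition_penalize / length_penalize helpers themselves are not shown in
# this diff, so the sketches below are only assumptions about what such helpers
# typically do; they are not this repo's actual implementations.
import torch

def illustrative_repetition_penalize( logits, previous, factor=1.0 ):
    # scale the logit of every token that already appeared in `previous`,
    # so a factor above 1.0 makes repeating it less likely (CTRL-style penalty)
    if factor == 1.0 or previous.numel() == 0:
        return logits
    seen = torch.unique( previous )
    logits[..., seen] = torch.where( logits[..., seen] > 0, logits[..., seen] / factor, logits[..., seen] * factor )
    return logits

def illustrative_length_penalize( logits, length, factor=0.0, token=0 ):
    # (AR only) bias the stop token's logit by the current sequence length,
    # so a positive factor makes the model more eager to stop as the output grows
    logits[..., token] += length * factor
    return logits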