added documentation on how these new sampling parameters are very iffy, and that you really need to know what you are doing to use them, because this is audio generation and not text generation
This commit is contained in:
parent 14c78bae39
commit b922f35b6b

README.md | 10
@@ -124,9 +124,15 @@ To synthesize speech, invoke either (if exported the models): `python -m vall_e
 Some additional flags you can pass are:
 * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
-* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output.
-* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the most clean output.
 * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
+* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it work fine.
+* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides clean output, but values up to `0.6` seem fine too.
+
+And some experimental sampling flags you can use too (your mileage will ***definitely*** vary):
+* `--top-p`: limits the sampling pool to the smallest set of top tokens whose cumulative probability reaches `P`.
+* `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
+* `--repetition-penalty`: modifies the probability of tokens that have already appeared in the sequence. In the context of audio generation, this is a very iffy parameter to use.
+* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky.

 ## To-Do
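The `--top-k` and `--top-p` flags both restrict the sampling pool by masking logits before a token is drawn. As a minimal illustrative sketch only (this is not the repository's implementation, the function name is hypothetical, and the real code operates on PyTorch tensors rather than Python lists):

```python
import math

def top_k_top_p_filter(logits, top_k=0, top_p=1.0):
    """Return a copy of `logits` with filtered-out entries set to -inf.

    top_k: keep only the `top_k` largest logits (0 disables the filter).
    top_p: nucleus filtering -- keep the smallest set of most-likely
           tokens whose cumulative probability reaches `top_p`.
    """
    logits = list(logits)
    if top_k > 0:
        # value of the k-th largest logit; everything below it is masked
        kth = sorted(logits, reverse=True)[top_k - 1]
        logits = [l if l >= kth else -math.inf for l in logits]
    if top_p < 1.0:
        # indices sorted by descending logit
        order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        mx = max(logits)
        # softmax weights (exp(-inf) == 0, so already-masked tokens drop out)
        weights = [math.exp(logits[i] - mx) for i in order]
        total = sum(weights)
        cum, keep = 0.0, set()
        for idx, w in zip(order, weights):
            keep.add(idx)          # always keep at least the top token
            cum += w / total
            if cum >= top_p:
                break
        logits = [l if i in keep else -math.inf for i, l in enumerate(logits)]
    return logits
```

Masked tokens get probability zero after a softmax, so sampling is confined to the surviving pool; with audio codes, pruning too aggressively is one reason these flags can misbehave compared to text.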
@@ -411,8 +411,6 @@ class Base(nn.Module):
 			)

 			return logits
-
-

 		# (NAR) return the entire generated response
 		if quant_levels is not None:
@@ -427,7 +425,7 @@ class Base(nn.Module):
 		# perform repetition penalizing
 		logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]

-		# perform length penalizing
+		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
 			logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
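From the call sites above (`reptition_penalize(...)`, `length_penalize(...)`), the two penalties can be sketched roughly as follows. This is an assumption-laden illustration, not the repository's code: the signatures are inferred from the calls, the bodies follow the common CTRL-style convention for repetition penalties, and the real implementations operate on tensors:

```python
def repetition_penalize(logits, previous, factor):
    """Dampen the logits of tokens already seen in `previous`
    (CTRL-style convention: divide positive logits by the factor,
    multiply negative ones). Sketch only -- actual code may differ."""
    logits = list(logits)
    for token in set(previous):
        l = logits[token]
        logits[token] = l / factor if l > 0 else l * factor
    return logits

def length_penalize(logits, length, factor, token):
    """(AR only) Bias the stop token's logit by the current sequence
    length, so longer sequences become more (or less) likely to end.
    Sketch only -- actual code may differ."""
    logits = list(logits)
    logits[token] += factor * length
    return logits
```

This also shows why these knobs are iffy for audio: EnCodec-style audio codes legitimately repeat far more often than words in text, so penalizing repeats can degrade otherwise-correct output, and a miscalibrated length penalty either truncates speech or prevents it from ever stopping.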