diff --git a/README.md b/README.md
index 5ed42ae..e4f0884 100755
--- a/README.md
+++ b/README.md
@@ -124,9 +124,15 @@ To synthesize speech, invoke either (if exported the models): `python -m vall_e
 
 Some additional flags you can pass are:
 * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
-* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output.
-* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the most clean output.
 * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
+* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it work fine.
+* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides clean output, but values upward of `0.6` seem fine too.
+
+And some experimental sampling flags you can use too (your mileage will ***definitely*** vary):
+* `--top-p`: limits the sampling pool to the top tokens whose cumulative probability reaches `P`.
+* `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
+* `--repetition-penalty`: modifies the probability of tokens that have already appeared. In the context of audio generation, this is a very iffy parameter to use.
+* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky.
 
 ## To-Do
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b40f528..9e6534c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -411,8 +411,6 @@ class Base(nn.Module):
 			)
 			return logits
 
-
-
 		# (NAR) return the entire generated response
 		if quant_levels is not None:
@@ -427,7 +425,7 @@ class Base(nn.Module):
 		# perform repetition penalizing
 		logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
 
-		# perform length penalizing
+		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
 			logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
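
A note on the new `--top-p` / `--top-k` flags documented above: both work by masking the model's logits before a token is sampled. The sketch below is a minimal illustration of that filtering for a single logits vector, assuming PyTorch; the function name `filter_logits` and its body are hypothetical and are not the repo's implementation.

```python
import torch

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
	"""Hypothetical sketch: mask logits outside the top-k / top-p pool with -inf (1-D logits)."""
	logits = logits.clone()

	# --top-k: keep only the K largest logits
	if top_k > 0:
		top_k = min(top_k, logits.shape[-1])
		kth_best = torch.topk(logits, top_k).values[-1]
		logits[logits < kth_best] = float("-inf")

	# --top-p: keep the smallest prefix of tokens whose probabilities sum to P
	if top_p < 1.0:
		sorted_logits, sorted_idx = torch.sort(logits, descending=True)
		probs = torch.softmax(sorted_logits, dim=-1)
		cumulative = probs.cumsum(dim=-1)
		# drop a token if the mass accumulated *before* it already exceeds P,
		# so the single most likely token is always kept
		drop = (cumulative - probs) > top_p
		logits[sorted_idx[drop]] = float("-inf")

	return logits

# hypothetical usage: sample from the filtered distribution
# probs = torch.softmax(filter_logits(logits, top_k=64, top_p=0.9), dim=-1)
# token = torch.multinomial(probs, 1)
```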
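Similarly, `reptition_penalize` and `length_penalize` are only visible in the base.py hunk at their call sites. The sketch below shows what logit-level repetition and length penalties of that shape commonly look like; only the keyword arguments (`previous`, `factor`, `length`, `token`) come from the diff, and the bodies are assumptions for illustration, not the repo's actual code.

```python
import torch

def reptition_penalize(logits: torch.Tensor, previous: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
	"""Hypothetical sketch (name spelled as at the call site): penalize tokens already seen in `previous`."""
	if factor == 1.0 or previous is None:
		return logits
	logits = logits.clone()
	for token in previous.unique().tolist():
		# CTRL-style penalty: divide positive logits, multiply negative ones
		logits[..., token] = torch.where(
			logits[..., token] > 0,
			logits[..., token] / factor,
			logits[..., token] * factor,
		)
	return logits

def length_penalize(logits: torch.Tensor, length: int, factor: float = 0.0, token: int = -1) -> torch.Tensor:
	"""Hypothetical sketch: bias the stop token's logit as the generated sequence grows."""
	if factor == 0.0 or token < 0:
		return logits
	logits = logits.clone()
	logits[..., token] += factor * length
	return logits
```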