From b922f35b6bf163734943ac3c0995f089edd6f0ea Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 8 Sep 2023 20:43:36 -0500 Subject: [PATCH] added documentation on how these new sampling parameters are very iffy and you really need to know what you are doing to use them because this is audio generation and not text generation --- README.md | 10 ++++++++-- vall_e/models/base.py | 4 +--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 5ed42ae..e4f0884 100755 --- a/README.md +++ b/README.md @@ -124,9 +124,15 @@ To synthesize speech, invoke either (if exported the models): `python -m vall_e Some additional flags you can pass are: * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps. -* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output. -* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides the most clean output. * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`) +* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it work fine. +* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides clean output, but values upward of `0.6` seem fine too. + +And some experimental sampling flags you can use too (your mileage will ***definitely*** vary): +* `--top-p`: limits the sampling pool to the top values that sum to `P`% probability in the probability distribution. +* `--top-k`: limits the sampling pool to the top `K` values in the probability distribution. +* `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use. 
+* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky. ## To-Do diff --git a/vall_e/models/base.py b/vall_e/models/base.py index b40f528..9e6534c 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -411,8 +411,6 @@ class Base(nn.Module): ) return logits - - # (NAR) return the entire generated response if quant_levels is not None: @@ -427,7 +425,7 @@ class Base(nn.Module): # perform repetition penalizing logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ] - # perform length penalizing + # (AR) perform length penalizing if quant_levels is None and self.causal: logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]