1
1
forked from mrq/tortoise-tts

added setting to adjust autoregressive sample batch size

This commit is contained in:
mrq 2023-02-06 22:31:06 +00:00
parent d8c88078f3
commit a3c077ba13
3 changed files with 15 additions and 3 deletions

View File

@ -145,6 +145,7 @@ Below are settings that override the default launch arguments. Some of these req
* `Check for Updates`: checks for updates on page load and notifies in console. Only works if you pulled this repo from a gitea instance.
* `Low VRAM`: disables optimizations in TorToiSe that increases VRAM consumption. Suggested if your GPU has under 6GiB.
* `Voice Latent Max Chunk Size`: during the voice latents calculation pass, this limits how large, in bytes, a chunk can be. Large values can run into VRAM OOM errors.
* `Sample Batch Size`: sets the batch size when generating autoregressive samples. Bigger batches result in faster compute, at the cost of increased VRAM consumption. Leave to 0 to calculate a "best" fit.
* `Concurrency Count`: how many Gradio events the queue can process at once. Leave this over 1 if you want to modify settings in the UI that updates other settings while generating audio clips.
Below are an explanation of experimental flags. Messing with these might impact performance, as these are exposed only if you know what you are doing.

14
app.py
View File

@ -53,6 +53,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
'cond_free_k': 2.0, 'diffusion_temperature': 1.0,
'num_autoregressive_samples': num_autoregressive_samples,
'sample_batch_size': args.sample_batch_size,
'diffusion_iterations': diffusion_iterations,
'voice_samples': voice_samples,
@ -309,11 +310,12 @@ def check_for_updates():
def update_voices():
return gr.Dropdown.update(choices=os.listdir(os.listdir("./tortoise/voices")) + ["microphone"])
def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_chunk_size, concurrency_count ):
def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_chunk_size, sample_batch_size, concurrency_count ):
args.share = share
args.low_vram = low_vram
args.check_for_updates = check_for_updates
args.cond_latent_max_chunk_size = cond_latent_max_chunk_size
args.sample_batch_size = sample_batch_size
args.concurrency_count = concurrency_count
settings = {
@ -321,6 +323,7 @@ def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_ch
'low-vram':args.low_vram,
'check-for-updates':args.check_for_updates,
'cond-latent-max-chunk-size': args.cond_latent_max_chunk_size,
'sample-batch-size': args.sample_batch_size,
'concurrency-count': args.concurrency_count,
}
@ -428,6 +431,7 @@ def main():
exec_check_for_updates = gr.Checkbox(label="Check For Updates", value=args.check_for_updates)
exec_arg_low_vram = gr.Checkbox(label="Low VRAM", value=args.low_vram)
exec_arg_cond_latent_max_chunk_size = gr.Number(label="Voice Latents Max Chunk Size", precision=0, value=args.cond_latent_max_chunk_size)
exec_arg_sample_batch_size = gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size)
exec_arg_concurrency_count = gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count)
@ -435,7 +439,7 @@ def main():
check_updates_now = gr.Button(value="Check for Updates")
exec_inputs = [exec_arg_share, exec_check_for_updates, exec_arg_low_vram, exec_arg_cond_latent_max_chunk_size, exec_arg_concurrency_count]
exec_inputs = [exec_arg_share, exec_check_for_updates, exec_arg_low_vram, exec_arg_cond_latent_max_chunk_size, exec_arg_sample_batch_size, exec_arg_concurrency_count]
for i in exec_inputs:
i.change(
@ -490,18 +494,22 @@ if __name__ == "__main__":
'check-for-updates': False,
'low-vram': False,
'cond-latent-max-chunk-size': 1000000,
'sample-batch-size': None,
'concurrency-count': 3,
}
if os.path.isfile('./config/exec.json'):
with open(f'./config/exec.json', 'r', encoding="utf-8") as f:
default_arguments = json.load(f)
overrides = json.load(f)
for k in overrides:
default_arguments[k] = overrides[k]
parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
parser.add_argument("--cond-latent-max-chunk-size", default=default_arguments['cond-latent-max-chunk-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
args = parser.parse_args()

View File

@ -391,6 +391,7 @@ class TextToSpeech:
return_deterministic_state=False,
# autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
sample_batch_size=None,
# CVVP parameters follow
cvvp_amount=.0,
# diffusion generation parameters follow
@ -464,6 +465,8 @@ class TextToSpeech:
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if sample_batch_size is None else sample_batch_size
with torch.no_grad():
samples = []
num_batches = num_autoregressive_samples // self.autoregressive_batch_size