From be6fab9dcb8a1b317bbe7b1270bc0649ca2405f0 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 6 Feb 2023 22:31:06 +0000 Subject: [PATCH] added setting to adjust autoregressive sample batch size --- README.md | 1 + app.py | 14 +++++++++++--- tortoise/api.py | 3 +++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ecc02cc..2923ba6 100755 --- a/README.md +++ b/README.md @@ -145,6 +145,7 @@ Below are settings that override the default launch arguments. Some of these req * `Check for Updates`: checks for updates on page load and notifies in console. Only works if you pulled this repo from a gitea instance. * `Low VRAM`: disables optimizations in TorToiSe that increases VRAM consumption. Suggested if your GPU has under 6GiB. * `Voice Latent Max Chunk Size`: during the voice latents calculation pass, this limits how large, in bytes, a chunk can be. Large values can run into VRAM OOM errors. +* `Sample Batch Size`: sets the batch size when generating autoregressive samples. Bigger batches result in faster compute, at the cost of increased VRAM consumption. Leave to 0 to calculate a "best" fit. * `Concurrency Count`: how many Gradio events the queue can process at once. Leave this over 1 if you want to modify settings in the UI that updates other settings while generating audio clips. Below are an explanation of experimental flags. Messing with these might impact performance, as these are exposed only if you know what you are doing. diff --git a/app.py b/app.py index 4bc5a6a..e137fff 100755 --- a/app.py +++ b/app.py @@ -53,6 +53,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate 'cond_free_k': 2.0, 'diffusion_temperature': 1.0, 'num_autoregressive_samples': num_autoregressive_samples, + 'sample_batch_size': args.sample_batch_size, 'diffusion_iterations': diffusion_iterations, 'voice_samples': voice_samples, @@ -309,11 +310,12 @@ def check_for_updates(): def update_voices(): return gr.Dropdown.update(choices=os.listdir(os.listdir("./tortoise/voices")) + ["microphone"]) -def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_chunk_size, concurrency_count ): +def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_chunk_size, sample_batch_size, concurrency_count ): args.share = share args.low_vram = low_vram args.check_for_updates = check_for_updates args.cond_latent_max_chunk_size = cond_latent_max_chunk_size + args.sample_batch_size = sample_batch_size args.concurrency_count = concurrency_count settings = { @@ -321,6 +323,7 @@ def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_ch 'low-vram':args.low_vram, 'check-for-updates':args.check_for_updates, 'cond-latent-max-chunk-size': args.cond_latent_max_chunk_size, + 'sample-batch-size': args.sample_batch_size, 'concurrency-count': args.concurrency_count, } @@ -428,6 +431,7 @@ def main(): exec_check_for_updates = gr.Checkbox(label="Check For Updates", value=args.check_for_updates) exec_arg_low_vram = gr.Checkbox(label="Low VRAM", value=args.low_vram) exec_arg_cond_latent_max_chunk_size = gr.Number(label="Voice Latents Max Chunk Size", precision=0, value=args.cond_latent_max_chunk_size) + exec_arg_sample_batch_size = gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size) exec_arg_concurrency_count = gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count) @@ -435,7 +439,7 @@ def main(): check_updates_now = gr.Button(value="Check for Updates") - exec_inputs = [exec_arg_share, exec_check_for_updates, exec_arg_low_vram, exec_arg_cond_latent_max_chunk_size, exec_arg_concurrency_count] + exec_inputs = [exec_arg_share, exec_check_for_updates, exec_arg_low_vram, exec_arg_cond_latent_max_chunk_size, exec_arg_sample_batch_size, exec_arg_concurrency_count] for i in exec_inputs: i.change( @@ -490,18 +494,22 @@ if __name__ == "__main__": 'check-for-updates': False, 'low-vram': False, 'cond-latent-max-chunk-size': 1000000, + 'sample-batch-size': None, 'concurrency-count': 3, } if os.path.isfile('./config/exec.json'): with open(f'./config/exec.json', 'r', encoding="utf-8") as f: - default_arguments = json.load(f) + overrides = json.load(f) + for k in overrides: + default_arguments[k] = overrides[k] parser = argparse.ArgumentParser() parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere") parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup") parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage") parser.add_argument("--cond-latent-max-chunk-size", default=default_arguments['cond-latent-max-chunk-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents") + parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents") parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") args = parser.parse_args() diff --git a/tortoise/api.py b/tortoise/api.py index dc14ff6..bc2eff8 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -391,6 +391,7 @@ class TextToSpeech: return_deterministic_state=False, # autoregressive generation parameters follow num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + sample_batch_size=None, # CVVP parameters follow cvvp_amount=.0, # diffusion generation parameters follow @@ -464,6 +465,8 @@ class TextToSpeech: diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) + self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if sample_batch_size is None else sample_batch_size + with torch.no_grad(): samples = [] num_batches = num_autoregressive_samples // self.autoregressive_batch_size