From 328deeddae7c9e0b21cb4b8c56df3d84810f916b Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 6 Feb 2023 23:14:17 -0600
Subject: [PATCH] forgot to auto compute batch size again if set to 0

---
 tortoise/api.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index bc2eff8..07f4878 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -228,7 +228,7 @@ class TextToSpeech:
         """
         self.minor_optimizations = minor_optimizations
         self.models_dir = models_dir
-        self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
+        self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None or autoregressive_batch_size == 0 else autoregressive_batch_size
         self.enable_redaction = enable_redaction
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         if self.enable_redaction:
@@ -465,7 +465,7 @@ class TextToSpeech:
         diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)

-        self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if sample_batch_size is None else sample_batch_size
+        self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if sample_batch_size is None or sample_batch_size == 0 else sample_batch_size

         with torch.no_grad():
             samples = []
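
The change treats a batch size of 0 the same as None: both now fall back to the auto-computed value instead of silently keeping 0. Below is a minimal sketch of that fallback logic outside the patch context; `pick_best_batch_size_for_gpu` is stubbed with a placeholder heuristic and `resolve_batch_size` is a hypothetical helper, not code from tortoise/api.py.

```python
from typing import Optional


def pick_best_batch_size_for_gpu() -> int:
    """Placeholder heuristic standing in for the real VRAM-based picker."""
    return 16


def resolve_batch_size(requested: Optional[int]) -> int:
    # Treat both None and 0 as "auto". Before this patch, only None
    # triggered auto-computation, so passing 0 left the batch size at 0.
    if requested is None or requested == 0:
        return pick_best_batch_size_for_gpu()
    return requested


assert resolve_batch_size(None) == 16  # auto
assert resolve_batch_size(0) == 16     # auto (the case this patch fixes)
assert resolve_batch_size(8) == 8      # explicit value is respected
```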