From 0bcdf81d0444218b4dedaefa5c546d42f36b8130 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 21 Mar 2023 21:33:46 +0000
Subject: [PATCH] option to decouple sample batch size from CLVP candidate
 selection size (currently just unsqueezes the batches)

---
 tortoise/api.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index 3aeb5e1..9b091ca 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -267,8 +267,9 @@ class TextToSpeech:
 
     def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None,
                  minor_optimizations=True,
+                 unsqueeze_sample_batches=False,
                  input_sample_rate=22050, output_sample_rate=24000,
-                 autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None
+                 autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
                  ):
         """
         Constructor
@@ -289,6 +290,7 @@ class TextToSpeech:
         self.input_sample_rate = input_sample_rate
         self.output_sample_rate = output_sample_rate
         self.minor_optimizations = minor_optimizations
+        self.unsqueeze_sample_batches = unsqueeze_sample_batches
 
         # for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
         self.preloaded_tensors = minor_optimizations
@@ -697,8 +699,14 @@ class TextToSpeech:
 
         if not self.preloaded_tensors:
             self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
 
-        clip_results = []
+        if self.unsqueeze_sample_batches:
+            new_samples = []
+            for batch in samples:
+                for i in range(batch.shape[0]):
+                    new_samples.append(batch[i].unsqueeze(0))
+            samples = new_samples
 
+        clip_results = []
         if auto_conds is not None:
             auto_conditioning = migrate_to_device( auto_conditioning, self.device )
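
For reference, a minimal standalone sketch of what the unsqueeze_sample_batches
path does to the samples list; tensor shapes and sizes here are illustrative
assumptions, the real tensors hold autoregressive codes:

    import torch

    # as produced by the autoregressive stage: a list of batched code
    # tensors, e.g. two batches of 4 candidates each (shapes illustrative)
    samples = [torch.zeros(4, 500, dtype=torch.long) for _ in range(2)]

    # the patched path: split every batch into single-candidate batches so
    # CLVP scores candidates one at a time rather than per sample batch,
    # decoupling the sampling batch size from the CLVP selection batch size
    new_samples = []
    for batch in samples:
        for i in range(batch.shape[0]):
            new_samples.append(batch[i].unsqueeze(0))  # [500] -> [1, 500]
    samples = new_samples

    assert len(samples) == 8              # 2 * 4 single-candidate batches
    assert samples[0].shape == (1, 500)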