From 5a5fd9ca87b08d5e8276eca911a164781342b61d Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Tue, 21 Mar 2023 21:34:26 +0000
Subject: [PATCH] Added option to unsqueeze sample batches after sampling

---
 modules/dlas         | 2 +-
 modules/tortoise-tts | 2 +-
 src/utils.py         | 6 +++++-
 src/webui.py         | 1 +
 4 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/modules/dlas b/modules/dlas
index 7b5e059..a4afad8 160000
--- a/modules/dlas
+++ b/modules/dlas
@@ -1 +1 @@
-Subproject commit 7b5e0592f875772cfed27f00fe16928a503c582a
+Subproject commit a4afad8837404b6d99c0a7da0f4337da6e34fc61
diff --git a/modules/tortoise-tts b/modules/tortoise-tts
index af78e39..0bcdf81 160000
--- a/modules/tortoise-tts
+++ b/modules/tortoise-tts
@@ -1 +1 @@
-Subproject commit af78e3978a381e5c38aa83c6be8a9f09eb6efebf
+Subproject commit 0bcdf81d0444218b4dedaefa5c546d42f36b8130
diff --git a/src/utils.py b/src/utils.py
index 181f1d9..f86d4de 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -2010,6 +2010,7 @@ def setup_args():
 		'models-from-local-only': False,
 		'low-vram': False,
 		'sample-batch-size': None,
+		'unsqueeze-sample-batches': False,
 		'embed-output-metadata': True,
 		'latents-lean-and-mean': True,
 		'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
@@ -2067,6 +2068,7 @@ def setup_args():
 	parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
 	parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
 	parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
+	parser.add_argument("--unsqueeze_sample_batches", default=default_arguments['unsqueeze_sample_batches'], action='store_true', help="Unsqueezes sample batches to process one by one after sampling")
 	parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
 	parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
 	parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
@@ -2131,6 +2133,7 @@ def get_default_settings( hypenated=True ):
 		'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
 		'device-override': args.device_override,
 		'sample-batch-size': args.sample_batch_size,
+		'unsqueeze-sample-batches': args.unsqueeze_sample_batches,
 		'embed-output-metadata': args.embed_output_metadata,
 		'latents-lean-and-mean': args.latents_lean_and_mean,
 		'voice-fixer': args.voice_fixer,
@@ -2178,6 +2181,7 @@ def update_args( **kwargs ):
 	args.prune_nonfinal_outputs = settings['prune_nonfinal_outputs']
 	args.device_override = settings['device_override']
 	args.sample_batch_size = settings['sample_batch_size']
+	args.unsqueeze_sample_batches = settings['unsqueeze_sample_batches']
 	args.embed_output_metadata = settings['embed_output_metadata']
 	args.latents_lean_and_mean = settings['latents_lean_and_mean']
 	args.voice_fixer = settings['voice_fixer']
@@ -2344,7 +2348,7 @@ def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vo
 
 	tts_loading = True
 	print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {vocoder_model})")
-	tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json)
+	tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
 	tts_loading = False
 
 	get_model_path('dvae.pth')
diff --git a/src/webui.py b/src/webui.py
index cdf3a2c..6cb87b0 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -570,6 +570,7 @@ def setup_gradio():
 					EXEC_SETTINGS['prune_nonfinal_outputs'] = gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs)
 				with gr.Column():
 					EXEC_SETTINGS['sample_batch_size'] = gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size)
+					EXEC_SETTINGS['unsqueeze_sample_batches'] = gr.Checkbox(label="Unsqueeze Sample Batches", value=args.unsqueeze_sample_batches)
 					EXEC_SETTINGS['concurrency_count'] = gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count)
 					EXEC_SETTINGS['autocalculate_voice_chunk_duration_size'] = gr.Number(label="Auto-Calculate Voice Chunk Duration (in seconds)", precision=0, value=args.autocalculate_voice_chunk_duration_size)
 					EXEC_SETTINGS['output_volume'] = gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume)