From efa556b79348255f8cacd6db21d8c4706dae1206 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Fri, 10 Feb 2023 03:02:09 +0000
Subject: [PATCH] Added new options: "Output Sample Rate", "Output Volume", and
 documentation

---
 README.md       |   2 +
 app.py          | 137 +++++++++++++++++++++++++++++++-----------------
 tortoise/api.py |   2 +-
 3 files changed, 91 insertions(+), 50 deletions(-)

diff --git a/README.md b/README.md
index 858fea7..589bd77 100755
--- a/README.md
+++ b/README.md
@@ -197,6 +197,8 @@ Below are settings that override the default launch arguments. Some of these req
 * `Voice Latent Max Chunk Size`: during the voice latents calculation pass, this limits how large, in bytes, a chunk can be. Large values can run into VRAM OOM errors.
 * `Sample Batch Size`: sets the batch size when generating autoregressive samples. Bigger batches result in faster compute, at the cost of increased VRAM consumption. Leave to 0 to calculate a "best" fit.
 * `Concurrency Count`: how many Gradio events the queue can process at once. Leave this over 1 if you want to modify settings in the UI that updates other settings while generating audio clips.
+* `Output Sample Rate`: the sample rate to save the generated audio as. Resampling from the model's native rate can provide a slight bump in quality.
+* `Output Volume`: adjusts the output volume through amplitude scaling.
 
 Below are an explanation of experimental flags. Messing with these might impact performance, as these are exposed only if you know what you are doing.
 * `Half-Precision`: (attempts to) hint to PyTorch to auto-cast to float16 (half precision) for compute. Disabled by default, due to it making computations slower.
diff --git a/app.py b/app.py
index c0a0422..631333e 100755
--- a/app.py
+++ b/app.py
@@ -21,6 +21,12 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices
 from tortoise.utils.text import split_and_recombine_text
 
 def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidates, num_autoregressive_samples, diffusion_iterations, temperature, diffusion_sampler, breathing_room, cvvp_weight, experimentals, progress=gr.Progress(track_tqdm=True)):
+    try:
+        tts
+    except NameError:
+        raise gr.Error("TTS is still initializing...")
+
+
     if voice != "microphone":
         voices = [voice]
     else:
@@ -36,7 +42,8 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
         voice_samples, conditioning_latents = load_voice(voice)
 
     if voice_samples is not None:
-        sample_voice = voice_samples[0]
+        sample_voice = voice_samples[0].squeeze().cpu()
+
         conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, max_chunk_size=args.cond_latent_max_chunk_size)
         if len(conditioning_latents) == 4:
             conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
@@ -54,7 +61,6 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
         print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
         cvvp_weight = 0
 
-    start_time = time.time()
 
     settings = {
         'temperature': temperature, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
@@ -86,14 +92,24 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
     else:
         texts = split_and_recombine_text(text)
  
+    start_time = time.time()
  
-    timestamp = int(time.time())
-    outdir = f"./results/{voice}/{timestamp}/"
- 
+    outdir = f"./results/{voice}/{int(start_time)}/"
     os.makedirs(outdir, exist_ok=True)
- 
 
     audio_cache = {}
+
+    resampler = torchaudio.transforms.Resample(
+        tts.output_sample_rate,
+        args.output_sample_rate,
+        lowpass_filter_width=16,
+        rolloff=0.85,
+        resampling_method="kaiser_window",
+        beta=8.555504641634386,
+    ) if tts.output_sample_rate != args.output_sample_rate else None
+
+    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
+
     for line, cut_text in enumerate(texts):
         if emotion == "Custom":
             if prompt.strip() != "":
@@ -108,21 +124,27 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
  
         if isinstance(gen, list):
             for j, g in enumerate(gen):
-                audio = g.squeeze(0).cpu()
+                os.makedirs(f'{outdir}/candidate_{j}', exist_ok=True)
                 audio_cache[f"candidate_{j}/result_{line}.wav"] = {
-                    'audio': audio,
+                    'audio': g,
                     'text': cut_text,
                 }
-
-                os.makedirs(f'{outdir}/candidate_{j}', exist_ok=True)
-                torchaudio.save(f'{outdir}/candidate_{j}/result_{line}.wav', audio, tts.output_sample_rate)
         else:
-            audio = gen.squeeze(0).cpu()
             audio_cache[f"result_{line}.wav"] = {
-                'audio': audio,
+                'audio': gen,
                 'text': cut_text,
             }
-            torchaudio.save(f'{outdir}/result_{line}.wav', audio, tts.output_sample_rate)
+
+    for k in audio_cache:
+        audio = audio_cache[k]['audio'].squeeze(0).cpu()
+        if resampler is not None:
+            audio = resampler(audio)
+        if volume_adjust is not None:
+            audio = volume_adjust(audio)
+
+        audio_cache[k]['audio'] = audio
+        torchaudio.save(f'{outdir}/{k}', audio, args.output_sample_rate)
+
  
     output_voice = None
     if len(texts) > 1:
@@ -136,7 +158,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
                 audio_clips.append(audio)
             
             audio = torch.cat(audio_clips, dim=-1)
-            torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, tts.output_sample_rate)
+            torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, args.output_sample_rate)
 
             audio = audio.squeeze(0).cpu()
             audio_cache[f'combined_{candidate}.wav'] = {
@@ -145,15 +167,15 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
             }
 
             if output_voice is None:
-                output_voice = audio
+                output_voice = f'{outdir}/combined_{candidate}.wav'
+            #    output_voice = audio
     else:
         if isinstance(gen, list):
-            output_voice = gen[0]
+            output_voice = f'{outdir}/candidate_0/result_0.wav'
+            #output_voice = gen[0]
         else:
-            output_voice = gen
-    
-    if output_voice is not None:
-        output_voice = (tts.output_sample_rate, output_voice.numpy())
+            output_voice = f'{outdir}/result_0.wav'
+            #output_voice = gen
 
     info = {
         'text': text,
@@ -188,9 +210,12 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
             metadata = music_tag.load_file(f"{outdir}/{path}")
             metadata['lyrics'] = json.dumps(info) 
             metadata.save()
+
+    #if output_voice is not None:
+    #    output_voice = (args.output_sample_rate, output_voice.numpy())
  
     if sample_voice is not None:
-        sample_voice = (tts.input_sample_rate, sample_voice.squeeze().cpu().numpy())
+        sample_voice = (tts.input_sample_rate, sample_voice.numpy())
  
     print(f"Generation took {info['time']} seconds, saved to '{outdir}'\n")
 
@@ -319,10 +344,13 @@ def check_for_updates():
 
     return False
 
+def reload_tts():
+    tts = setup_tortoise()
+
 def update_voices():
     return gr.Dropdown.update(choices=sorted(os.listdir("./tortoise/voices")) + ["microphone"])
 
-def export_exec_settings( share, listen, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, cond_latent_max_chunk_size, sample_batch_size, concurrency_count ):
+def export_exec_settings( share, listen, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, cond_latent_max_chunk_size, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
     args.share = share
     args.listen = listen
     args.low_vram = low_vram
@@ -333,6 +361,8 @@ def export_exec_settings( share, listen, check_for_updates, models_from_local_on
     args.embed_output_metadata = embed_output_metadata
     args.latents_lean_and_mean = latents_lean_and_mean
     args.concurrency_count = concurrency_count
+    args.output_sample_rate = output_sample_rate
+    args.output_volume = output_volume
 
     settings = {
         'share': args.share,
@@ -345,6 +375,8 @@ def export_exec_settings( share, listen, check_for_updates, models_from_local_on
         'embed-output-metadata': args.embed_output_metadata,
         'latents-lean-and-mean': args.latents_lean_and_mean,
         'concurrency-count': args.concurrency_count,
+        'output-sample-rate': args.output_sample_rate,
+        'output-volume': args.output_volume,
     }
 
     with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
@@ -361,7 +393,9 @@ def setup_args():
         'embed-output-metadata': True,
         'latents-lean-and-mean': True,
         'cond-latent-max-chunk-size': 1000000,
-        'concurrency-count': 3,
+        'concurrency-count': 2,
+        'output-sample-rate': 44100,
+        'output-volume': 1,
     }
 
     if os.path.isfile('./config/exec.json'):
@@ -381,6 +415,8 @@ def setup_args():
     parser.add_argument("--cond-latent-max-chunk-size", default=default_arguments['cond-latent-max-chunk-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
     parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets an upper limit to audio chunk size when computing conditioning latents")
     parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
+    parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
+    parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
     args = parser.parse_args()
 
     args.embed_output_metadata = not args.no_embed_output_metadata
@@ -392,7 +428,7 @@ def setup_args():
         match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
 
         args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
-        args.listen_port = match[1] if match[1] != "" else 8000
+        args.listen_port = match[1] if match[1] != "" else None
         args.listen_path = match[2] if match[2] != "" else "/"
 
     if args.listen_port is not None:
@@ -516,34 +552,37 @@ def setup_gradio():
                     )
         with gr.Tab("Settings"):
             with gr.Row():
+                exec_inputs = []
                 with gr.Column():
-                    with gr.Box():
-                        exec_arg_listen = gr.Textbox(label="Listen", value=args.listen, placeholder="127.0.0.1:7860/")
-                        exec_arg_share = gr.Checkbox(label="Public Share Gradio", value=args.share)
-                        exec_arg_check_for_updates = gr.Checkbox(label="Check For Updates", value=args.check_for_updates)
-                        exec_arg_models_from_local_only = gr.Checkbox(label="Only Load Models Locally", value=args.models_from_local_only)
-                        exec_arg_low_vram = gr.Checkbox(label="Low VRAM", value=args.low_vram)
-                        exec_arg_embed_output_metadata = gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata)
-                        exec_arg_latents_lean_and_mean = gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean)
-                        exec_arg_cond_latent_max_chunk_size = gr.Number(label="Voice Latents Max Chunk Size", precision=0, value=args.cond_latent_max_chunk_size)
-                        exec_arg_sample_batch_size = gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size)
-                        exec_arg_concurrency_count = gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count)
-
+                    exec_inputs = exec_inputs + [
+                        gr.Textbox(label="Listen", value=args.listen, placeholder="127.0.0.1:7860/"),
+                        gr.Checkbox(label="Public Share Gradio", value=args.share),
+                        gr.Checkbox(label="Check For Updates", value=args.check_for_updates),
+                        gr.Checkbox(label="Only Load Models Locally", value=args.models_from_local_only),
+                        gr.Checkbox(label="Low VRAM", value=args.low_vram),
+                        gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata),
+                        gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean),
+                    ]
+                    gr.Button(value="Check for Updates").click(check_for_updates)
+                with gr.Column():
+                    exec_inputs = exec_inputs + [
+                        gr.Number(label="Voice Latents Max Chunk Size", precision=0, value=args.cond_latent_max_chunk_size),
+                        gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
+                        gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
+                        gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate),
+                        gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
+                    ]
 
+                for i in exec_inputs:
+                    i.change(
+                        fn=export_exec_settings,
+                        inputs=exec_inputs
+                        )
+                with gr.Column():
                     experimentals = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags")
                     cvvp_weight = gr.Slider(value=0, minimum=0, maximum=1, label="CVVP Weight")
 
-                    check_updates_now = gr.Button(value="Check for Updates")
-
-                    exec_inputs = [exec_arg_share, exec_arg_listen, exec_arg_check_for_updates, exec_arg_models_from_local_only, exec_arg_low_vram, exec_arg_embed_output_metadata, exec_arg_latents_lean_and_mean, exec_arg_cond_latent_max_chunk_size, exec_arg_sample_batch_size, exec_arg_concurrency_count]
-
-                    for i in exec_inputs:
-                        i.change(
-                            fn=export_exec_settings,
-                            inputs=exec_inputs
-                        )
-
-                    check_updates_now.click(check_for_updates)
+                    gr.Button(value="Reload TTS").click(reload_tts)
 
         input_settings = [
             text,
@@ -591,7 +630,7 @@ if __name__ == "__main__":
 
     if args.listen_path is not None and args.listen_path != "/":
         import uvicorn
-        uvicorn.run("app:app", host=args.listen_host, port=args.listen_port)
+        uvicorn.run("app:app", host=args.listen_host, port=args.listen_port if not None else 8000)
     else:
         webui = setup_gradio()
         webui.launch(share=args.share, prevent_thread_lock=True, server_name=args.listen_host, server_port=args.listen_port)
diff --git a/tortoise/api.py b/tortoise/api.py
index 43071d2..60dc56b 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -562,7 +562,7 @@ class TextToSpeech:
             # results, but will increase memory usage.
             if not self.minor_optimizations:
                 self.autoregressive = self.autoregressive.to(self.device)
-
+            
             if get_device_name() == "dml":
                 text_tokens = text_tokens.cpu()
                 best_results = best_results.cpu()