From f38c479e9b0c2e3f29454cf9aad965673ecd5a23 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Sun, 5 Feb 2023 06:17:51 +0000
Subject: [PATCH] Added multi-line parsing

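Splits the input prompt into multiple lines, either on a user-supplied
delimiter (the new "Multi-Line Delimiter" textbox; a literal \n is
treated as a newline) or, when no delimiter is given, with tortoise's
split_and_recombine_text(). Each line is generated separately, the
per-line WAVs are written under ./results/{voice}/{timestamp}/, and one
combined_{candidate}.wav is concatenated per candidate. Also removes a
stray sampler debug print from tortoise/api.py.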
---
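For review, a standalone sketch of the new splitting step; split_prompt is an
illustrative name, not a function added by this patch:

    from tortoise.utils.text import split_and_recombine_text

    def split_prompt(text, delimiter):
        # a literal "\n" typed into the single-line textbox arrives as two
        # characters, so map it to a real newline first
        if delimiter == "\\n":
            delimiter = "\n"
        if delimiter != "" and delimiter in text:
            return text.split(delimiter)
        # no usable delimiter: fall back to tortoise's sentence-aware splitter
        return split_and_recombine_text(text)

    # split_prompt("Hello there.\nGeneral Kenobi.", "\\n")
    #   -> ["Hello there.", "General Kenobi."]
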
 app.py          | 74 +++++++++++++++++++++++++++++++++++--------------
 tortoise/api.py |  3 +-
 2 files changed, 54 insertions(+), 23 deletions(-)
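
app.py branches on isinstance(gen, list) because tts.tts() returns a bare
tensor when one candidate is requested and a list of tensors when several
are; a sketch of handling both shapes uniformly (the patch keeps the
explicit branch instead):

    import torchaudio

    # assumes `tts`, `cut_text` and `settings` as set up in generate()
    gen, additionals = tts.tts(cut_text, **settings)
    outputs = gen if isinstance(gen, list) else [gen]
    for j, g in enumerate(outputs):
        # squeeze the leading batch dim so torchaudio.save gets [channels, samples]
        torchaudio.save(f'result_{j}.wav', g.squeeze(0).cpu(), 24000)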

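The recombination pass re-reads each per-line clip from disk (the first TODO
in app.py notes this) and concatenates along the sample axis; combine_lines
is an illustrative name, not a function added by this patch:

    import os
    import torch
    import torchaudio
    from tortoise.utils.audio import load_audio

    def combine_lines(outdir, n_lines, candidate_dir=''):
        # clips were saved at 24 kHz, so load_audio performs no resampling here
        clips = [load_audio(os.path.join(outdir, candidate_dir, f'result_{i}.wav'), 24000)
                 for i in range(n_lines)]
        combined = torch.cat(clips, dim=-1)  # concatenate along the time axis
        torchaudio.save(os.path.join(outdir, candidate_dir, 'combined.wav'), combined, 24000)
        return combined
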
diff --git a/app.py b/app.py
index 9cb4738..2fb4bf2 100755
--- a/app.py
+++ b/app.py
@@ -8,8 +8,9 @@ import time
 from datetime import datetime
 from tortoise.api import TextToSpeech
 from tortoise.utils.audio import load_audio, load_voice, load_voices
+from tortoise.utils.text import split_and_recombine_text
 
-def generate(text, emotion, prompt, voice, mic_audio, preset, seed, candidates, num_autoregressive_samples, diffusion_iterations, temperature, diffusion_sampler, progress=gr.Progress()):
+def generate(text, delimiter, emotion, prompt, voice, mic_audio, preset, seed, candidates, num_autoregressive_samples, diffusion_iterations, temperature, diffusion_sampler, progress=gr.Progress()):
     if voice != "microphone":
         voices = [voice]
     else:
@@ -58,38 +59,67 @@ def generate(text, emotion, prompt, voice, mic_audio, preset, seed, candidates,
         'progress': progress,
     }
 
-    gen, additionals = tts.tts( text, **settings )
-    seed = additionals[0]
-
-    info = f"{datetime.now()} | Voice: {','.join(voices)} | Text: {text} | Quality: {preset} preset / {num_autoregressive_samples} samples / {diffusion_iterations} iterations | Temperature: {temperature} | Diffusion Sampler: {diffusion_sampler} | Time Taken (s): {time.time()-start_time} | Seed: {seed}\n".encode('utf8')
-    with open("results.log", "w") as f:
-        f.write(info)
+    if delimiter == "\\n":
+        delimiter = "\n"
 
+    if delimiter != "" and delimiter in text:
+        texts = text.split(delimiter)
+    else:
+        texts = split_and_recombine_text(text)
+
+
     timestamp = int(time.time())
     outdir = f"./results/{voice}/{timestamp}/"
-
+
     os.makedirs(outdir, exist_ok=True)
+
+    # TODO: keep generated audio in memory to avoid re-reading it from disk when combining
+    # TODO: skip recombination when the text was not split into multiple lines
+
+    for line, cut_text in enumerate(texts):
+        print(f"[{str(line+1)}/{str(len(texts))}] Generating line: {cut_text}")
 
-    with open(os.path.join(outdir, f'input.txt'), 'w') as f:
-        f.write(f"{info}")
+        gen, additionals = tts.tts(cut_text, **settings)
+        seed = additionals[0]
+
+        if isinstance(gen, list):
+            for j, g in enumerate(gen):
+                os.makedirs(os.path.join(outdir, f'candidate_{j}'), exist_ok=True)
+                torchaudio.save(os.path.join(outdir, f'candidate_{j}', f'result_{line}.wav'), g.squeeze(0).cpu(), 24000)
+        else:
+            torchaudio.save(os.path.join(outdir, f'result_{line}.wav'), gen.squeeze(0).cpu(), 24000)
+
+    for candidate in range(candidates):
+        audio_clips = []
+        for line in range(len(texts)):
+            if isinstance(gen, list):
+                wav_file = os.path.join(outdir, f'candidate_{candidate}', f'result_{line}.wav')
+            else:
+                wav_file = os.path.join(outdir, f'result_{line}.wav')
 
-    if isinstance(gen, list):
-        for j, g in enumerate(gen):
-            torchaudio.save(os.path.join(outdir, f'result_{j}.wav'), g.squeeze(0).cpu(), 24000)
-        
-        output_voice = gen[0]
-    else:
-        torchaudio.save(os.path.join(outdir, f'result.wav'), gen.squeeze(0).cpu(), 24000)
-        output_voice = gen
+            audio_clips.append(load_audio(wav_file, 24000))
+        audio_clips = torch.cat(audio_clips, dim=-1)
+        torchaudio.save(os.path.join(outdir, f'combined_{candidate}.wav'), audio_clips, 24000)
+
+    info = f"{datetime.now()} | Voice: {','.join(voices)} | Text: {text} | Quality: {preset} preset / {num_autoregressive_samples} samples / {diffusion_iterations} iterations | Temperature: {temperature} | Time Taken (s): {time.time()-start_time} | Seed: {seed}\n"
+    
+    with open(os.path.join(outdir, f'input.txt'), 'w', encoding="utf-8") as f:
+        f.write(info)
 
-    output_voice = (24000, output_voice.squeeze().cpu().numpy())
+    with open("results.log", "w", encoding="utf-8") as f:
+        f.write(info)
 
+    print(f"Saved to '{outdir}'")
+
+    output_voice = (24000, audio_clips.squeeze().cpu().numpy())
+
     if sample_voice is not None:
         sample_voice = (22050, sample_voice.squeeze().cpu().numpy())
-
+
+    audio_clips = []
     return (
         sample_voice,
-        output_voice,
+        output_voice,
         seed
     )
 
@@ -112,6 +142,7 @@ def main():
         with gr.Row():
             with gr.Column():
                 text = gr.Textbox(lines=4, label="Prompt")
+                delimiter = gr.Textbox(lines=1, label="Multi-Line Delimiter", placeholder="\\n")
 
                 emotion = gr.Radio(
                     ["None", "Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"],
@@ -179,6 +210,7 @@ def main():
                 submit_event = submit.click(generate,
                     inputs=[
                         text,
+                        delimiter,
                         emotion,
                         prompt,
                         voice,
diff --git a/tortoise/api.py b/tortoise/api.py
index 643864c..a8bf54f 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -6,7 +6,7 @@ from urllib import request
 
 if 'TORTOISE_MODELS_DIR' not in os.environ:
     os.environ['TORTOISE_MODELS_DIR'] = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../models/tortoise/')
-    
+
 if 'TRANSFORMERS_CACHE' not in os.environ:
     os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../models/transformers/')
 
@@ -170,7 +170,6 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_la
         noise = torch.randn(output_shape, device=latents.device) * temperature
         
         mel = None
-        print(f"Sampler: {sampler}")
         if sampler == "P":
             mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
                                       model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},