From 7cc0250a1a559da90965812fdefcba0d54a59c41 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Fri, 24 Feb 2023 23:10:04 +0000
Subject: [PATCH] Add more kill-signal checks, since the check previously ran
 only on the first iteration of a loop

---
 tortoise/api.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index 7ff3621..115cea6 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -42,12 +42,15 @@ MODELS = {
     'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
 }
 
-def tqdm_override(arr, verbose=False, progress=None, desc=None):
+def check_for_kill_signal():
     global STOP_SIGNAL
     if STOP_SIGNAL:
         STOP_SIGNAL = False
         raise Exception("Kill signal detected")
 
+def tqdm_override(arr, verbose=False, progress=None, desc=None):
+    check_for_kill_signal()
+
     if verbose and desc is not None:
         print(desc)
 
@@ -368,6 +371,7 @@ class TextToSpeech:
             # expand / truncate samples to match the common size
             # required, as tensors need to be of the same length
             for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
+                check_for_kill_signal()
                 chunk = pad_or_truncate(chunk, chunk_size)
                 cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
                 diffusion_conds.append(cond_mel)
@@ -524,6 +528,7 @@ class TextToSpeech:
 
             with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
                 for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
+                    check_for_kill_signal()
                     codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
                                                                  do_sample=True,
                                                                  top_p=top_p,
@@ -565,6 +570,7 @@ class TextToSpeech:
                         desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
 
                 for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
+                    check_for_kill_signal()
                     for i in range(batch.shape[0]):
                         batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)