From 7cc0250a1a559da90965812fdefcba0d54a59c41 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 24 Feb 2023 23:10:04 +0000 Subject: [PATCH] added more kill checks, since it only actually did it for the first iteration of a loop --- tortoise/api.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tortoise/api.py b/tortoise/api.py index 7ff3621..115cea6 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -42,12 +42,15 @@ MODELS = { 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', } -def tqdm_override(arr, verbose=False, progress=None, desc=None): +def check_for_kill_signal(): global STOP_SIGNAL if STOP_SIGNAL: STOP_SIGNAL = False raise Exception("Kill signal detected") +def tqdm_override(arr, verbose=False, progress=None, desc=None): + check_for_kill_signal() + if verbose and desc is not None: print(desc) @@ -368,6 +371,7 @@ class TextToSpeech: # expand / truncate samples to match the common size # required, as tensors need to be of the same length for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."): + check_for_kill_signal() chunk = pad_or_truncate(chunk, chunk_size) cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device) diffusion_conds.append(cond_mel) @@ -524,6 +528,7 @@ class TextToSpeech: with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p): for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"): + check_for_kill_signal() codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, do_sample=True, top_p=top_p, @@ -565,6 +570,7 @@ class TextToSpeech: desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%" for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc): + check_for_kill_signal() for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)