added more kill checks, since it only actually did it for the first iteration of a loop

This commit is contained in:
mrq 2023-02-24 23:10:04 +00:00
parent de46cf7831
commit 7cc0250a1a

View File

@ -42,12 +42,15 @@ MODELS = {
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
} }
def tqdm_override(arr, verbose=False, progress=None, desc=None): def check_for_kill_signal():
global STOP_SIGNAL global STOP_SIGNAL
if STOP_SIGNAL: if STOP_SIGNAL:
STOP_SIGNAL = False STOP_SIGNAL = False
raise Exception("Kill signal detected") raise Exception("Kill signal detected")
def tqdm_override(arr, verbose=False, progress=None, desc=None):
check_for_kill_signal()
if verbose and desc is not None: if verbose and desc is not None:
print(desc) print(desc)
@ -368,6 +371,7 @@ class TextToSpeech:
# expand / truncate samples to match the common size # expand / truncate samples to match the common size
# required, as tensors need to be of the same length # required, as tensors need to be of the same length
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."): for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size) chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device) cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
diffusion_conds.append(cond_mel) diffusion_conds.append(cond_mel)
@ -524,6 +528,7 @@ class TextToSpeech:
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p): with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"): for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
check_for_kill_signal()
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True, do_sample=True,
top_p=top_p, top_p=top_p,
@ -565,6 +570,7 @@ class TextToSpeech:
desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%" desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc): for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
check_for_kill_signal()
for i in range(batch.shape[0]): for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)