forked from mrq/tortoise-tts
added more kill checks, since it only actually did it for the first iteration of a loop
This commit is contained in:
parent
de46cf7831
commit
7cc0250a1a
|
@ -42,12 +42,15 @@ MODELS = {
|
|||
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
||||
}
|
||||
|
||||
def tqdm_override(arr, verbose=False, progress=None, desc=None):
|
||||
def check_for_kill_signal():
|
||||
global STOP_SIGNAL
|
||||
if STOP_SIGNAL:
|
||||
STOP_SIGNAL = False
|
||||
raise Exception("Kill signal detected")
|
||||
|
||||
def tqdm_override(arr, verbose=False, progress=None, desc=None):
|
||||
check_for_kill_signal()
|
||||
|
||||
if verbose and desc is not None:
|
||||
print(desc)
|
||||
|
||||
|
@ -368,6 +371,7 @@ class TextToSpeech:
|
|||
# expand / truncate samples to match the common size
|
||||
# required, as tensors need to be of the same length
|
||||
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
|
||||
check_for_kill_signal()
|
||||
chunk = pad_or_truncate(chunk, chunk_size)
|
||||
cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
|
||||
diffusion_conds.append(cond_mel)
|
||||
|
@ -524,6 +528,7 @@ class TextToSpeech:
|
|||
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
|
||||
for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
|
||||
check_for_kill_signal()
|
||||
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
|
@ -565,6 +570,7 @@ class TextToSpeech:
|
|||
desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
|
||||
|
||||
for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
|
||||
check_for_kill_signal()
|
||||
for i in range(batch.shape[0]):
|
||||
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user