added more kill checks, since the kill check only actually ran on the first iteration of a loop
parent de46cf7831
commit 7cc0250a1a
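Note on the change: tqdm_override() only runs its kill check once, at the moment the wrapped iterable is created, so a stop request issued mid-loop was not noticed until the loop had already finished. Calling check_for_kill_signal() at the top of each loop body makes cancellation take effect on the very next iteration. A minimal, self-contained sketch of the pattern (everything except STOP_SIGNAL and check_for_kill_signal is illustrative, not code from this repository):

import threading
import time

STOP_SIGNAL = False  # flipped to True by whatever requests the cancel

def check_for_kill_signal():
    global STOP_SIGNAL
    if STOP_SIGNAL:
        STOP_SIGNAL = False  # reset so the next run is not aborted as well
        raise Exception("Kill signal detected")

def generate(num_batches=50):
    for b in range(num_batches):
        check_for_kill_signal()  # checked every iteration, not just once
        time.sleep(0.1)          # stand-in for one batch of real work

def cancel_soon():
    global STOP_SIGNAL
    time.sleep(0.5)
    STOP_SIGNAL = True

threading.Thread(target=cancel_soon, daemon=True).start()
try:
    generate()
except Exception as e:
    print(e)  # "Kill signal detected", after roughly five iterations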
@@ -42,12 +42,15 @@ MODELS = {
     'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
 }
 
-def tqdm_override(arr, verbose=False, progress=None, desc=None):
+def check_for_kill_signal():
     global STOP_SIGNAL
     if STOP_SIGNAL:
         STOP_SIGNAL = False
         raise Exception("Kill signal detected")
 
+def tqdm_override(arr, verbose=False, progress=None, desc=None):
+    check_for_kill_signal()
+
     if verbose and desc is not None:
         print(desc)
 
@@ -368,6 +371,7 @@ class TextToSpeech:
         # expand / truncate samples to match the common size
         # required, as tensors need to be of the same length
         for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
+            check_for_kill_signal()
             chunk = pad_or_truncate(chunk, chunk_size)
             cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
             diffusion_conds.append(cond_mel)
@@ -524,6 +528,7 @@ class TextToSpeech:
 
         with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
             for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
+                check_for_kill_signal()
                 codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
                     do_sample=True,
                     top_p=top_p,
@@ -565,6 +570,7 @@ class TextToSpeech:
             desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
 
             for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
+                check_for_kill_signal()
                 for i in range(batch.shape[0]):
                     batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
 
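On the caller's side (not part of this diff), cancellation is expected to work by setting the module-level STOP_SIGNAL and catching the exception the next check raises. A hedged sketch; the module path, the tts_with_preset() entry point, and the request_cancel helper are assumptions for illustration, not shown in this commit:

import tortoise.api as api  # assumed module path for the file patched above

tts = api.TextToSpeech()

def request_cancel():
    # e.g. wired to a "Stop" button handler running in another thread
    api.STOP_SIGNAL = True

try:
    audio = tts.tts_with_preset("Hello world.", preset='fast')
except Exception as e:
    if "Kill signal detected" in str(e):
        print("Generation cancelled by user.")  # cooperative abort, not a real failure
    else:
        raise

Matching on the message string is only needed because the patch raises a bare Exception; the same pattern works with any way the caller can tell the cancellation apart from a genuine error.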