forked from mrq/tortoise-tts
(maybe) fixed a crash when using prompt redactions (emotions) on CPU, because for some reason wav2vec_alignment assumed CUDA was always available
This commit is contained in:
parent
d6b5d67f79
commit
5f934c5feb
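
The whole fix is one pattern, applied in three places below: treat CUDA as something to probe for, never as a default. A minimal sketch of that fallback, assuming nothing beyond what the diff shows (the helper name is illustrative and does not appear in the commit):

    import torch

    def resolve_device(device=None):
        # Hypothetical helper mirroring the diff's pattern: select CUDA only
        # when torch can actually see a GPU, otherwise fall back to CPU
        # instead of crashing on a hard-coded 'cuda'.
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)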

app.py (5 changed lines)

@@ -181,7 +181,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
     if sample_voice is not None:
         sample_voice = (22050, sample_voice.squeeze().cpu().numpy())
 
-    print(f"Saved to '{outdir}'")
+    print(f"Generation took {info['time']} seconds, saved to '{outdir}'\n")
 
     info['seed'] = settings['use_deterministic_seed']
     del info['latents']

@@ -332,9 +332,6 @@ def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_ch
 
 
 def main():
-    if not torch.cuda.is_available():
-        print("CUDA is NOT available for use.")
-
     with gr.Blocks() as webui:
         with gr.Tab("Generate"):
             with gr.Row():

tortoise/api.py

@@ -226,13 +226,21 @@ class TextToSpeech:
             Default is true.
         :param device: Device to use when running the model. If omitted, the device will be automatically chosen.
         """
+        if not torch.cuda.is_available():
+            print("CUDA is NOT available for use.")
+            # minor_optimizations = False
+            # enable_redaction = False
+
+        if device is None:
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
         self.minor_optimizations = minor_optimizations
         self.models_dir = models_dir
         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None or autoregressive_batch_size == 0 else autoregressive_batch_size
         self.enable_redaction = enable_redaction
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = device
         if self.enable_redaction:
-            self.aligner = Wav2VecAlignment()
+            self.aligner = Wav2VecAlignment(device=self.device)
 
         self.tokenizer = VoiceBpeTokenizer()
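With the device threaded through the constructor, a CPU-only machine can now opt in to redaction explicitly. A hypothetical usage sketch (the device keyword is the one documented in the docstring above; enable_redaction is an existing TextToSpeech option):

    import torch
    from tortoise.api import TextToSpeech

    # Force CPU end-to-end: the aligner is now built with the same device
    # instead of assuming 'cuda', so redaction no longer crashes without a GPU.
    tts = TextToSpeech(device=torch.device('cpu'), enable_redaction=True)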

tortoise/utils/wav2vec_alignment.py

@@ -49,7 +49,10 @@ class Wav2VecAlignment:
     """
     Uses wav2vec2 to perform audio<->text alignment.
     """
-    def __init__(self, device='cuda'):
+    def __init__(self, device=None):
+        if device is None:
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
         self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
         self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')

@@ -59,12 +62,16 @@ class Wav2VecAlignment:
         orig_len = audio.shape[-1]
 
         with torch.no_grad():
-            self.model = self.model.to(self.device)
+            if torch.cuda.is_available(): # This is unnecessary technically, but it's a placebo
+                self.model = self.model.to(self.device)
+
             audio = audio.to(self.device)
             audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
             clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
             logits = self.model(clip_norm).logits
-            self.model = self.model.cpu()
+            if torch.cuda.is_available():
+                self.model = self.model.cpu()
+
         logits = logits[0]
         pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
 
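The hunk above is the shuttle pattern this codebase uses for the aligner: park the model on CPU between calls, borrow the GPU only for the forward pass, then release the memory. On a CPU-only machine both moves were already no-ops, which is what the "placebo" comment concedes; the guards just make the intent explicit. A generic sketch of the pattern, with an illustrative function name that is not in the diff:

    import torch

    def forward_on_device(model, batch, device):
        # Keep the model parked on CPU between calls; borrow the GPU (if any)
        # for the forward pass, then hand the memory back afterwards.
        with torch.no_grad():
            if torch.cuda.is_available():
                model = model.to(device)
            out = model(batch.to(device))
            if torch.cuda.is_available():
                model = model.cpu()
        return out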