forked from mrq/tortoise-tts
(maybe) fixed an issue with using prompt redactions (emotions) on CPU causing a crash, because for some reason the wav2vec_alignment assumed CUDA was always available
This commit is contained in:
parent
d6b5d67f79
commit
5f934c5feb
5
app.py
5
app.py
|
@ -181,7 +181,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
|
|||
if sample_voice is not None:
|
||||
sample_voice = (22050, sample_voice.squeeze().cpu().numpy())
|
||||
|
||||
print(f"Saved to '{outdir}'")
|
||||
print(f"Generation took {info['time']} seconds, saved to '{outdir}'\n")
|
||||
|
||||
info['seed'] = settings['use_deterministic_seed']
|
||||
del info['latents']
|
||||
|
@ -332,9 +332,6 @@ def export_exec_settings( share, check_for_updates, low_vram, cond_latent_max_ch
|
|||
|
||||
|
||||
def main():
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA is NOT available for use.")
|
||||
|
||||
with gr.Blocks() as webui:
|
||||
with gr.Tab("Generate"):
|
||||
with gr.Row():
|
||||
|
|
|
@ -226,13 +226,21 @@ class TextToSpeech:
|
|||
Default is true.
|
||||
:param device: Device to use when running the model. If omitted, the device will be automatically chosen.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA is NOT available for use.")
|
||||
# minor_optimizations = False
|
||||
# enable_redaction = False
|
||||
|
||||
if device is None:
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
self.minor_optimizations = minor_optimizations
|
||||
self.models_dir = models_dir
|
||||
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None or autoregressive_batch_size == 0 else autoregressive_batch_size
|
||||
self.enable_redaction = enable_redaction
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.device = device
|
||||
if self.enable_redaction:
|
||||
self.aligner = Wav2VecAlignment()
|
||||
self.aligner = Wav2VecAlignment(device=self.device)
|
||||
|
||||
self.tokenizer = VoiceBpeTokenizer()
|
||||
|
||||
|
|
|
@ -49,7 +49,10 @@ class Wav2VecAlignment:
|
|||
"""
|
||||
Uses wav2vec2 to perform audio<->text alignment.
|
||||
"""
|
||||
def __init__(self, device='cuda'):
|
||||
def __init__(self, device=None):
|
||||
if device is None:
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
|
||||
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
|
||||
|
@ -59,12 +62,16 @@ class Wav2VecAlignment:
|
|||
orig_len = audio.shape[-1]
|
||||
|
||||
with torch.no_grad():
|
||||
self.model = self.model.to(self.device)
|
||||
if torch.cuda.is_available(): # This is unneccessary technically, but it's a placebo
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
audio = audio.to(self.device)
|
||||
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
|
||||
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||
logits = self.model(clip_norm).logits
|
||||
self.model = self.model.cpu()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cpu()
|
||||
|
||||
logits = logits[0]
|
||||
pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
|
||||
|
|
Loading…
Reference in New Issue
Block a user