forked from mrq/tortoise-tts
Remove intelligibility refinement
It's not longer a concern. :)
This commit is contained in:
parent
2eb5d4b0cb
commit
cf80d7317c
26
api.py
26
api.py
|
@ -5,9 +5,7 @@ from urllib import request
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import progressbar
|
||||
import ocotillo
|
||||
|
||||
from models.diffusion_decoder import DiffusionTts
|
||||
from models.autoregressive import UnifiedVoice
|
||||
|
@ -262,27 +260,3 @@ class TextToSpeech:
|
|||
if len(wav_candidates) > 1:
|
||||
return wav_candidates
|
||||
return wav_candidates[0]
|
||||
|
||||
def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
|
||||
"""
|
||||
Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
|
||||
TODO: finish this function
|
||||
:param wav_candidates:
|
||||
:return:
|
||||
"""
|
||||
transcriber = ocotillo.Transcriber(on_cuda=True)
|
||||
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
|
||||
best = 99999999
|
||||
for i, transcription in enumerate(transcriptions):
|
||||
dist = lev_distance(transcription, args.text.lower())
|
||||
if dist < best:
|
||||
best = dist
|
||||
best_codes = corresponding_codes[i].unsqueeze(0)
|
||||
best_wav = wav_candidates[i]
|
||||
del transcriber
|
||||
torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
|
||||
|
||||
# Perform diffusion again with the high-quality diffuser.
|
||||
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
|
||||
wav = vocoder.inference(mel)
|
||||
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
|
|
@ -8,4 +8,3 @@ progressbar
|
|||
einops
|
||||
unidecode
|
||||
x-transformers
|
||||
ocotillo
|
Loading…
Reference in New Issue
Block a user