forked from mrq/tortoise-tts
Remove intelligibility refinement
It's not longer a concern. :)
This commit is contained in:
parent
2eb5d4b0cb
commit
cf80d7317c
26
api.py
26
api.py
|
@ -5,9 +5,7 @@ from urllib import request
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
|
||||||
import progressbar
|
import progressbar
|
||||||
import ocotillo
|
|
||||||
|
|
||||||
from models.diffusion_decoder import DiffusionTts
|
from models.diffusion_decoder import DiffusionTts
|
||||||
from models.autoregressive import UnifiedVoice
|
from models.autoregressive import UnifiedVoice
|
||||||
|
@ -262,27 +260,3 @@ class TextToSpeech:
|
||||||
if len(wav_candidates) > 1:
|
if len(wav_candidates) > 1:
|
||||||
return wav_candidates
|
return wav_candidates
|
||||||
return wav_candidates[0]
|
return wav_candidates[0]
|
||||||
|
|
||||||
def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
|
|
||||||
"""
|
|
||||||
Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
|
|
||||||
TODO: finish this function
|
|
||||||
:param wav_candidates:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
transcriber = ocotillo.Transcriber(on_cuda=True)
|
|
||||||
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
|
|
||||||
best = 99999999
|
|
||||||
for i, transcription in enumerate(transcriptions):
|
|
||||||
dist = lev_distance(transcription, args.text.lower())
|
|
||||||
if dist < best:
|
|
||||||
best = dist
|
|
||||||
best_codes = corresponding_codes[i].unsqueeze(0)
|
|
||||||
best_wav = wav_candidates[i]
|
|
||||||
del transcriber
|
|
||||||
torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
|
|
||||||
|
|
||||||
# Perform diffusion again with the high-quality diffuser.
|
|
||||||
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
|
|
||||||
wav = vocoder.inference(mel)
|
|
||||||
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
|
|
|
@ -8,4 +8,3 @@ progressbar
|
||||||
einops
|
einops
|
||||||
unidecode
|
unidecode
|
||||||
x-transformers
|
x-transformers
|
||||||
ocotillo
|
|
Loading…
Reference in New Issue
Block a user