forked from mrq/tortoise-tts
port do_tts to use the API
This commit is contained in:
parent
9db06e139b
commit
287debd1d3
29
api.py
29
api.py
|
@ -151,10 +151,10 @@ class TextToSpeech:
|
||||||
|
|
||||||
def tts(self, text, voice_samples, k=1,
|
def tts(self, text, voice_samples, k=1,
|
||||||
# autoregressive generation parameters follow
|
# autoregressive generation parameters follow
|
||||||
num_autoregressive_samples=512, temperature=.9, length_penalty=1, repetition_penalty=1.0, top_k=50, top_p=.95,
|
num_autoregressive_samples=512, temperature=.5, length_penalty=2, repetition_penalty=2.0, top_p=.5,
|
||||||
typical_sampling=False, typical_mass=.9,
|
typical_sampling=False, typical_mass=.9,
|
||||||
# diffusion generation parameters follow
|
# diffusion generation parameters follow
|
||||||
diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1,):
|
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=.7,):
|
||||||
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
||||||
text = F.pad(text, (0, 1)) # This may not be necessary.
|
text = F.pad(text, (0, 1)) # This may not be necessary.
|
||||||
|
|
||||||
|
@ -181,7 +181,6 @@ class TextToSpeech:
|
||||||
for b in tqdm(range(num_batches)):
|
for b in tqdm(range(num_batches)):
|
||||||
codes = self.autoregressive.inference_speech(conds, text,
|
codes = self.autoregressive.inference_speech(conds, text,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
num_return_sequences=self.autoregressive_batch_size,
|
num_return_sequences=self.autoregressive_batch_size,
|
||||||
|
@ -221,3 +220,27 @@ class TextToSpeech:
|
||||||
if len(wav_candidates) > 1:
|
if len(wav_candidates) > 1:
|
||||||
return wav_candidates
|
return wav_candidates
|
||||||
return wav_candidates[0]
|
return wav_candidates[0]
|
||||||
|
|
||||||
|
def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
|
||||||
|
"""
|
||||||
|
Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
|
||||||
|
TODO: finish this function
|
||||||
|
:param wav_candidates:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
transcriber = ocotillo.Transcriber(on_cuda=True)
|
||||||
|
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
|
||||||
|
best = 99999999
|
||||||
|
for i, transcription in enumerate(transcriptions):
|
||||||
|
dist = lev_distance(transcription, args.text.lower())
|
||||||
|
if dist < best:
|
||||||
|
best = dist
|
||||||
|
best_codes = corresponding_codes[i].unsqueeze(0)
|
||||||
|
best_wav = wav_candidates[i]
|
||||||
|
del transcriber
|
||||||
|
torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
|
||||||
|
|
||||||
|
# Perform diffusion again with the high-quality diffuser.
|
||||||
|
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
|
||||||
|
wav = vocoder.inference(mel)
|
||||||
|
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
|
206
do_tts.py
206
do_tts.py
|
@ -1,123 +1,13 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
from urllib import request
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import progressbar
|
|
||||||
import ocotillo
|
|
||||||
|
|
||||||
from models.diffusion_decoder import DiffusionTts
|
|
||||||
from models.autoregressive import UnifiedVoice
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from models.arch_util import TorchMelSpectrogram
|
|
||||||
from models.text_voice_clip import VoiceCLIP
|
|
||||||
from models.vocoder import UnivNetGenerator
|
|
||||||
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
|
|
||||||
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
|
||||||
from utils.tokenizer import VoiceBpeTokenizer, lev_distance
|
|
||||||
|
|
||||||
pbar = None
|
|
||||||
def download_models():
|
|
||||||
MODELS = {
|
|
||||||
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
|
|
||||||
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
|
|
||||||
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
|
|
||||||
}
|
|
||||||
os.makedirs('.models', exist_ok=True)
|
|
||||||
def show_progress(block_num, block_size, total_size):
|
|
||||||
global pbar
|
|
||||||
if pbar is None:
|
|
||||||
pbar = progressbar.ProgressBar(maxval=total_size)
|
|
||||||
pbar.start()
|
|
||||||
|
|
||||||
downloaded = block_num * block_size
|
|
||||||
if downloaded < total_size:
|
|
||||||
pbar.update(downloaded)
|
|
||||||
else:
|
|
||||||
pbar.finish()
|
|
||||||
pbar = None
|
|
||||||
for model_name, url in MODELS.items():
|
|
||||||
if os.path.exists(f'.models/{model_name}'):
|
|
||||||
continue
|
|
||||||
print(f'Downloading {model_name} from {url}...')
|
|
||||||
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
|
||||||
print('Done.')
|
|
||||||
|
|
||||||
|
|
||||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
|
|
||||||
"""
|
|
||||||
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
|
||||||
"""
|
|
||||||
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
|
||||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
|
||||||
conditioning_free=cond_free, conditioning_free_k=1)
|
|
||||||
|
|
||||||
|
|
||||||
def load_conditioning(path, sample_rate=22050, cond_length=132300):
|
|
||||||
rel_clip = load_audio(path, sample_rate)
|
|
||||||
gap = rel_clip.shape[-1] - cond_length
|
|
||||||
if gap < 0:
|
|
||||||
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
|
|
||||||
elif gap > 0:
|
|
||||||
rand_start = random.randint(0, gap)
|
|
||||||
rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
|
|
||||||
mel_clip = TorchMelSpectrogram()(rel_clip.unsqueeze(0)).squeeze(0)
|
|
||||||
return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
|
|
||||||
|
|
||||||
|
|
||||||
def fix_autoregressive_output(codes, stop_token):
|
|
||||||
"""
|
|
||||||
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
|
|
||||||