port do_tts to use the API

This commit is contained in:
James Betker 2022-04-01 11:55:07 -06:00
parent 9db06e139b
commit 287debd1d3
2 changed files with 36 additions and 201 deletions

29
api.py
View File

@ -151,10 +151,10 @@ class TextToSpeech:
def tts(self, text, voice_samples, k=1, def tts(self, text, voice_samples, k=1,
# autoregressive generation parameters follow # autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.9, length_penalty=1, repetition_penalty=1.0, top_k=50, top_p=.95, num_autoregressive_samples=512, temperature=.5, length_penalty=2, repetition_penalty=2.0, top_p=.5,
typical_sampling=False, typical_mass=.9, typical_sampling=False, typical_mass=.9,
# diffusion generation parameters follow # diffusion generation parameters follow
diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1,): diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=.7,):
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda() text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
text = F.pad(text, (0, 1)) # This may not be necessary. text = F.pad(text, (0, 1)) # This may not be necessary.
@ -181,7 +181,6 @@ class TextToSpeech:
for b in tqdm(range(num_batches)): for b in tqdm(range(num_batches)):
codes = self.autoregressive.inference_speech(conds, text, codes = self.autoregressive.inference_speech(conds, text,
do_sample=True, do_sample=True,
top_k=top_k,
top_p=top_p, top_p=top_p,
temperature=temperature, temperature=temperature,
num_return_sequences=self.autoregressive_batch_size, num_return_sequences=self.autoregressive_batch_size,
@ -221,3 +220,27 @@ class TextToSpeech:
if len(wav_candidates) > 1: if len(wav_candidates) > 1:
return wav_candidates return wav_candidates
return wav_candidates[0] return wav_candidates[0]
def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
"""
Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
TODO: finish this function
:param wav_candidates:
:return:
"""
transcriber = ocotillo.Transcriber(on_cuda=True)
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
best = 99999999
for i, transcription in enumerate(transcriptions):
dist = lev_distance(transcription, args.text.lower())
if dist < best:
best = dist
best_codes = corresponding_codes[i].unsqueeze(0)
best_wav = wav_candidates[i]
del transcriber
torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
# Perform diffusion again with the high-quality diffuser.
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
wav = vocoder.inference(mel)
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)

206
do_tts.py
View File

@ -1,123 +1,13 @@
import argparse import argparse
import os import os
import random
from urllib import request
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
import progressbar
import ocotillo
from models.diffusion_decoder import DiffusionTts
from models.autoregressive import UnifiedVoice
from tqdm import tqdm
from models.arch_util import TorchMelSpectrogram
from models.text_voice_clip import VoiceCLIP
from models.vocoder import UnivNetGenerator
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from utils.tokenizer import VoiceBpeTokenizer, lev_distance
pbar = None
def download_models():
MODELS = {
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
}
os.makedirs('.models', exist_ok=True)
def show_progress(block_num, block_size, total_size):
global pbar
if pbar is None:
pbar = progressbar.ProgressBar(maxval=total_size)
pbar.start()
downloaded = block_num * block_size
if downloaded < total_size:
pbar.update(downloaded)
else:
pbar.finish()
pbar = None
for model_name, url in MODELS.items():
if os.path.exists(f'.models/{model_name}'):
continue
print(f'Downloading {model_name} from {url}...')
request.urlretrieve(url, f'.models/{model_name}', show_progress)
print('Done.')
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
"""
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
"""
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=cond_free, conditioning_free_k=1)
def load_conditioning(path, sample_rate=22050, cond_length=132300):
rel_clip = load_audio(path, sample_rate)
gap = rel_clip.shape[-1] - cond_length
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
mel_clip = TorchMelSpectrogram()(rel_clip.unsqueeze(0)).squeeze(0)
return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
def fix_autoregressive_output(codes, stop_token):
"""
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was