# tortoise-tts/tortoise_tts/inference.py

import torch
import torchaudio
import soundfile
from torch import Tensor
from einops import rearrange
from pathlib import Path
from .emb.mel import encode_from_files as encode_mel, trim, trim_random
from .utils import to_device
from .utils import wrapper as ml
from .config import cfg
from .models import get_models, load_model
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, tokenize
from .models.arch_utils import denormalize_tacotron_mel
from .models.diffusion import get_diffuser
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
if deepspeed_available:
    import deepspeed

class TTS():
    def __init__( self, config=None, device=None, amp=None, dtype=None ):
        self.loading = True

        self.input_sample_rate = 24000
        self.output_sample_rate = 24000

        if config:
            cfg.load_yaml( config )

        try:
            cfg.format( training=False )
            cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
        except Exception as e:
            print("Error while parsing config YAML:")
            raise e # throw an error because I'm tired of silent errors messing things up for me

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None or dtype == "auto":
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        self.dtype = cfg.inference.dtype
        self.amp = amp

        self.symmap = None

        self.engines = load_engines(training=False)
        for name, engine in self.engines.items():
            if self.dtype != torch.int8:
                engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

        self.engines.eval()

        if self.symmap is None:
            self.symmap = get_phone_symmap()

        self.loading = False

    def encode_text( self, text, language="en" ):
        # already a tensor, return it
        if isinstance( text, Tensor ):
            return text

        tokens = tokenize( text )
        return torch.tensor( tokens )
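
    # tokenize() comes from .data and presumably maps the line to a list of
    # phoneme/symbol ids from the same symmap loaded in __init__; e.g. a
    # hypothetical call encode_text("hello") yields a 1-D LongTensor of ids.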

    def encode_audio( self, paths, trim_length=0.0 ):
        # already a tensor, return it
        if isinstance( paths, Tensor ):
            return paths

        # split string into paths
        if isinstance( paths, str ):
            paths = [ Path(p) for p in paths.split(";") ]

        # merge inputs
        return encode_mel( paths, device=self.device )
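
    # References may be given as one semicolon-delimited string, e.g. the
    # hypothetical call encode_audio("a.wav;b.wav"), which loads both clips
    # and merges them into a single set of conditioning latents.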

    @torch.inference_mode()
    def inference(
        self,
        text,
        references,
        #language="en",
        max_ar_steps=500,
        max_diffusion_steps=80,
        #max_ar_context=-1,
        #input_prompt_length=0.0,
        ar_temp=1.0,
        diffusion_temp=1.0,
        #min_ar_temp=0.95,
        #min_diffusion_temp=0.5,
        top_p=1.0,
        top_k=0,
        repetition_penalty=1.0,
        #repetition_penalty_decay=0.0,
        length_penalty=1.0,
        beam_width=1,
        #mirostat_tau=0,
        #mirostat_eta=0.1,
        diffusion_sampler="ddim",
        cond_free=True,
        out_path=None
    ):
        lines = text.split("\n")

        wavs = []
        sr = 24_000

        autoregressive = None
        diffusion = None
        clvp = None
        vocoder = None
        diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=cond_free)

        autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]

        for name, engine in self.engines.items():
            if "autoregressive" in name:
                autoregressive = engine.module
            elif "diffusion" in name:
                diffusion = engine.module
            elif "clvp" in name:
                clvp = engine.module
            elif "vocoder" in name:
                vocoder = engine.module

        if autoregressive is None:
            autoregressive = load_model("autoregressive", device=cfg.device)
        if diffusion is None:
            diffusion = load_model("diffusion", device=cfg.device)
        if clvp is None:
            clvp = load_model("clvp", device=cfg.device)
        if vocoder is None:
            vocoder = load_model("vocoder", device=cfg.device)

        # shove everything to cpu
        if cfg.inference.auto_unload:
            autoregressive = autoregressive.to("cpu")
            diffusion = diffusion.to("cpu")
            clvp = clvp.to("cpu")
            vocoder = vocoder.to("cpu")

        # other vars
        calm_token = 832
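        # NOTE: 832 is presumably this model's AR token for sustained silence;
        # the loop after the AR pass counts consecutive calm tokens and trims
        # the latents once more than 8 appear in a row.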

        for line in lines:
            if out_path is None:
                output_dir = Path("./data/results/")
                if not output_dir.exists():
                    output_dir.mkdir(parents=True, exist_ok=True)
                out_path = output_dir / f"{cfg.start_time}.wav"

            text = self.encode_text( line ).to(device=cfg.device)

            text_tokens = pad_sequence([ text ], batch_first = True)
            text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)

            with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
                with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
                    # autoregressive pass
                    codes = autoregressive.inference_speech(
                        autoregressive_latents,
                        text_tokens,
                        do_sample=True,
                        top_p=top_p,
                        temperature=ar_temp,
                        num_return_sequences=1,
                        length_penalty=length_penalty,
                        repetition_penalty=repetition_penalty,
                        max_generate_length=max_ar_steps,
                    )

                    """
                    padding_needed = max_ar_steps - codes.shape[1]
                    codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
                    """
                    for i, code in enumerate( codes ):
                        stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
                        # bail before calling .min(), which throws on an empty tensor
                        if len(stop_token_indices) == 0:
                            continue
                        stm = stop_token_indices.min().item()

                        codes[i][stop_token_indices] = 83
                        codes[i][stm:] = 83
                        if stm - 3 < codes[i].shape[0]:
                            codes[i][-3] = 45
                            codes[i][-2] = 45
                            codes[i][-1] = 248

                    wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)

                    latents = autoregressive.forward(
                        autoregressive_latents,
                        text_tokens,
                        text_lengths,
                        codes,
                        wav_lengths,
                        return_latent=True,
                        clip_inputs=False
                    )

                    calm_tokens = 0
                    for k in range( codes.shape[-1] ):
                        if codes[0, k] == calm_token:
                            calm_tokens += 1
                        else:
                            calm_tokens = 0
                        if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                            latents = latents[:, :k]
                            break
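
                # Worked example for the length math in the diffusion pass
                # below, assuming the stated rates: the AR latents sit at a
                # 4x-downsampled 22,050 Hz mel frame rate, so e.g. 100 latent
                # frames map to 100 * 4 * 24000 // 22050 = 435 mel frames at
                # the 24 kHz output rate.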
                # diffusion pass
                with ml.auto_unload(diffusion, enabled=cfg.inference.auto_unload):
                    output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
                    output_shape = (latents.shape[0], 100, output_seq_len)
                    precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)

                    noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
                    mel = diffuser.sample_loop(
                        diffusion,
                        output_shape,
                        sampler=diffusion_sampler,
                        noise=noise,
                        model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
                        progress=True
                    )
                    mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]

                # vocoder pass
                with ml.auto_unload(vocoder, enabled=cfg.inference.auto_unload):
                    waves = vocoder.inference(mels)

            for wav in waves:
                if out_path is not None:
                    torchaudio.save( out_path, wav.cpu(), sr )

                wavs.append(wav)

        return (torch.concat(wavs, dim=-1), sr)
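
# Minimal usage sketch (not part of the original file); the config and audio
# paths below are placeholders, and everything beyond `out_path` is left to
# inference()'s defaults above.
if __name__ == "__main__":
    tts = TTS( config="./config.yaml" )  # hypothetical YAML path
    wav, sr = tts.inference(
        "Hello world.",
        references="./reference.wav",  # hypothetical reference clip
        out_path="./out.wav",
    )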