vall-e/vall_e/inference.py

198 lines
5.5 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
import torch
import torchaudio
import soundfile
2023-08-21 02:36:02 +00:00
from torch import Tensor
2023-08-02 21:53:35 +00:00
from einops import rearrange
2023-08-21 02:36:02 +00:00
from pathlib import Path
2023-08-02 21:53:35 +00:00
from .emb import g2p, qnt
from .emb.qnt import trim, trim_random
2023-08-02 21:53:35 +00:00
from .utils import to_device
from .config import cfg
from .models import get_models
from .engines import load_engines, deepspeed_available
2023-08-25 04:33:36 +00:00
from .data import get_phone_symmap, _load_quants
2023-08-02 21:53:35 +00:00
if deepspeed_available:
import deepspeed
2023-08-02 21:53:35 +00:00
class TTS:
    """Text-to-speech inference front end for the AR and NAR models.

    Construction configures the global ``cfg`` for inferencing and loads the
    models; ``inference`` turns text plus reference audio into a waveform that
    is written to ``out_path``.
    """

    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
        """Configure ``cfg`` for inference and load the AR/NAR models.

        Args:
            config: optional YAML config path, passed to ``cfg.load_yaml``.
            ar_ckpt, nar_ckpt: checkpoint paths. Only when BOTH are given are the
                models built with ``get_models`` and loaded from them; otherwise
                ``load_models()`` (the engine loader) is used.
            device: target device; defaults to ``cfg.device``.
            amp: enable autocast; defaults to ``cfg.inference.amp``. When enabled
                the weights are kept in float32.
            dtype: weight dtype; defaults to ``cfg.inference.weight_dtype``.

        Note: this mutates the global ``cfg`` (device, mode, backend, weight dtype).
        """
        self.loading = True

        self.input_sample_rate = 24000
        self.output_sample_rate = 24000

        if config:
            cfg.load_yaml( config )
            cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing

        try:
            cfg.format()
        except Exception:
            # NOTE(review): failures are deliberately ignored here; presumably
            # formatting is optional for inferencing — confirm.
            pass

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None:
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        # presumably the torch dtype resolved from cfg.inference.weight_dtype set above — confirm
        self.dtype = cfg.inference.dtype
        self.amp = amp

        # phoneme -> id map; taken from a checkpoint if one carries it, else get_phone_symmap()
        self.symmap = None

        def parse( name, model, state ):
            # Pick up a symmap stored alongside the weights, if present.
            if "userdata" in state and 'symmap' in state['userdata']:
                self.symmap = state['userdata']['symmap']
            elif "symmap" in state:
                self.symmap = state['symmap']

            # Some checkpoints nest the weights under "module".
            if "module" in state:
                state = state['module']

            model.load_state_dict(state)

            if deepspeed_available:
                model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

            return model

        if ar_ckpt and nar_ckpt:
            self.ar_ckpt = ar_ckpt
            self.nar_ckpt = nar_ckpt

            models = get_models(cfg.models.get())
            for name, model in models.items():
                # torch.load unpickles arbitrary objects: only load trusted checkpoints.
                # map_location="cpu" lets GPU-saved checkpoints load on any host;
                # load_state_dict copies into the model's own parameters anyway,
                # and both models are moved to self.device below.
                if name.startswith("ar"):
                    state = torch.load(self.ar_ckpt, map_location="cpu")
                    self.ar = parse( name, model, state )
                elif name.startswith("nar"):
                    state = torch.load(self.nar_ckpt, map_location="cpu")
                    self.nar = parse( name, model, state )

                # a combined "ar+nar" model serves both roles
                if name.startswith("ar+nar"):
                    self.nar = self.ar
        else:
            self.load_models()

        # Under amp the weights stay in float32; the autocast context in
        # inference() handles reduced precision instead.
        weight_dtype = self.dtype if not self.amp else torch.float32
        self.ar = self.ar.to(self.device, dtype=weight_dtype)
        self.nar = self.nar.to(self.device, dtype=weight_dtype)

        self.ar.eval()
        self.nar.eval()

        if self.symmap is None:
            self.symmap = get_phone_symmap()

        self.loading = False

    def load_models( self ):
        """Populate ``self.ar`` / ``self.nar`` from the engines returned by ``load_engines()``.

        An engine whose name starts with "ar+nar" is used as both the AR and the NAR.
        """
        engines = load_engines()
        for name, engine in engines.items():
            if name.startswith("ar"):
                self.ar = engine.module
            elif name.startswith("nar"):
                self.nar = engine.module

            if name.startswith("ar+nar"):
                self.nar = self.ar

    def encode_text( self, text, language="en" ):
        """Convert text to a 1-D tensor of phoneme ids framed by start/end ids.

        Args:
            text: input text, or a Tensor that is assumed to be already encoded
                and is returned unchanged.
            language: language code forwarded to ``g2p.encode``.
        """
        if isinstance( text, Tensor ):
            return text

        content = g2p.encode(text, language=language)
        return self._phones_to_tensor( content )

    def _phones_to_tensor( self, content ):
        """Map phonemes to ids via ``self.symmap``, wrapped in start/end ids.

        Empty/falsy phonemes become " ". The start/end ids are ``symmap["<s>"]``
        and ``symmap["</s>"]`` when the map defines both, otherwise the
        hard-coded fallback ids 1 and 2.

        Raises:
            KeyError: if a phoneme is not present in ``self.symmap``.
        """
        phones = [ p if p else " " for p in content ]

        unknown = sorted({ p for p in phones if p not in self.symmap })
        if unknown:
            raise KeyError(f"phonemes missing from symmap: {unknown}")

        if "<s>" in self.symmap and "</s>" in self.symmap:
            bos, eos = self.symmap["<s>"], self.symmap["</s>"]
        else:
            bos, eos = 1, 2

        return torch.tensor([ bos ] + [ self.symmap[p] for p in phones ] + [ eos ])

    def encode_audio( self, paths, trim_length=0.0 ):
        """Quantize reference audio into an int16 prompt tensor.

        Args:
            paths: a Tensor (returned unchanged), a ";"-separated string of file
                paths, a single Path, or an iterable of paths. The encoded files
                are concatenated along dim 0 (presumably the time axis — confirm).
            trim_length: if non-zero, the prompt is trimmed to
                ``int(75 * trim_length)`` frames (75 presumably being the codec
                frame rate, making trim_length seconds — confirm).

        Raises:
            ValueError: if no paths are given.
        """
        if isinstance( paths, Tensor ):
            return paths

        paths = self._split_paths( paths )
        if not paths:
            raise ValueError("encode_audio: no reference audio paths given")

        res = torch.cat([ qnt.encode_from_file(path)[0].t().to(torch.int16) for path in paths ])
        if trim_length:
            res = trim( res, int( 75 * trim_length ) )

        return res

    @staticmethod
    def _split_paths( paths ):
        """Normalize the ``paths`` argument of ``encode_audio`` into a list.

        Strings are split on ";" with surrounding whitespace stripped and empty
        entries dropped; a single Path becomes a one-element list; any other
        iterable is returned as a list of its items unchanged.
        """
        if isinstance( paths, Path ):
            return [ paths ]
        if isinstance( paths, str ):
            return [ Path(p.strip()) for p in paths.split(";") if p.strip() ]
        return list( paths )

    @torch.inference_mode()
    def inference(
        self,
        text,
        references,
        max_ar_steps=6 * 75,
        max_nar_levels=7,
        input_prompt_length=0.0,
        ar_temp=0.95,
        nar_temp=0.5,
        top_p=1.0,
        top_k=0,
        repetition_penalty=1.0,
        repetition_penalty_decay=0.0,
        length_penalty=0.0,
        beam_width=0,
        mirostat_tau=0,
        mirostat_eta=0.1,
        out_path=None,
        language="en",
    ):
        """Synthesize speech for ``text`` conditioned on ``references``.

        Runs the AR model to produce the first level of codes, then the NAR
        model for the remaining levels, and decodes the result to ``out_path``.

        Args:
            text: text (or pre-encoded phoneme tensor) to speak.
            references: reference audio, in any form accepted by ``encode_audio``.
            max_ar_steps: AR step limit (presumably ~6 s at 75 frames/s — confirm).
            max_nar_levels: NAR level limit.
            input_prompt_length: prompt trim length passed to ``encode_audio``.
            ar_temp, nar_temp: sampling temperatures for the AR / NAR passes.
            top_p, top_k, repetition_penalty, repetition_penalty_decay: sampling
                controls shared by both passes.
            length_penalty, beam_width, mirostat_tau, mirostat_eta: AR-only
                sampling controls.
            out_path: output file; defaults to ``./data/{cfg.start_time}.wav``.
            language: language code forwarded to ``encode_text``.

        Returns:
            The ``(wav, sr)`` pair returned by ``qnt.decode_to_file``.
        """
        if out_path is None:
            out_path = f"./data/{cfg.start_time}.wav"

        prom = self.encode_audio( references, trim_length=input_prompt_length )
        phns = self.encode_text( text, language=language )

        prom = to_device(prom, self.device).to(torch.int16)
        # ids fit in uint8 when the symmap has fewer than 256 entries
        phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

        # Autocast must target the device the models run on: CPU inference gets
        # CPU autocast; any other device keeps the CUDA autocast context.
        autocast_device = "cpu" if str(self.device).split(":")[0] == "cpu" else "cuda"
        with torch.autocast(autocast_device, dtype=self.dtype, enabled=self.amp):
            resps_list = self.ar(
                text_list=[phns], proms_list=[prom], max_steps=max_ar_steps,
                sampling_temperature=ar_temp,
                sampling_top_p=top_p, sampling_top_k=top_k,
                sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
                sampling_length_penalty=length_penalty,
                sampling_beam_width=beam_width,
                sampling_mirostat_tau=mirostat_tau,
                sampling_mirostat_eta=mirostat_eta,
            )
            # add a trailing axis for the NAR (presumably the quantizer-level axis — confirm)
            resps_list = [r.unsqueeze(-1) for r in resps_list]
            resps_list = self.nar(
                text_list=[phns], proms_list=[prom], resps_list=resps_list,
                max_levels=max_nar_levels,
                sampling_temperature=nar_temp,
                sampling_top_p=top_p, sampling_top_k=top_k,
                sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
            )

        wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)

        return (wav, sr)