2023-08-02 21:53:35 +00:00
|
|
|
import torch
|
|
|
|
import torchaudio
|
|
|
|
import soundfile
|
|
|
|
|
2023-08-21 02:36:02 +00:00
|
|
|
from torch import Tensor
|
2023-08-02 21:53:35 +00:00
|
|
|
from einops import rearrange
|
2023-08-21 02:36:02 +00:00
|
|
|
from pathlib import Path
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
from .emb import g2p, qnt
|
2023-08-23 21:43:03 +00:00
|
|
|
from .emb.qnt import trim, trim_random
|
2023-08-02 21:53:35 +00:00
|
|
|
from .utils import to_device
|
|
|
|
|
|
|
|
from .config import cfg
|
2023-08-14 03:07:45 +00:00
|
|
|
from .models import get_models
|
2023-10-09 20:24:04 +00:00
|
|
|
from .engines import load_engines, deepspeed_available
|
2023-10-13 04:21:01 +00:00
|
|
|
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2023-10-09 20:24:04 +00:00
|
|
|
if deepspeed_available:
|
2023-10-09 19:46:17 +00:00
|
|
|
import deepspeed
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
class TTS():
    """Text-to-speech inference wrapper.

    Construction configures the global ``cfg`` for inferencing and loads a model,
    either from an explicit checkpoint or from ``load_engines``. ``inference()``
    then turns text plus reference audio into a waveform.
    """

    def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ):
        """Configure ``cfg`` for inferencing and load the model.

        Args:
            config: optional YAML config path, passed to ``cfg.load_yaml``.
            model_ckpt: optional checkpoint path. When given, the model is built with
                ``get_models`` and its weights (and symmap, if stored) come from this
                file; otherwise the first engine returned by ``load_engines`` is used.
            device: target device; defaults to ``cfg.device``.
            amp: enable autocast in ``inference``; defaults to ``cfg.inference.amp``.
            dtype: weight dtype; ``None`` or ``"auto"`` means ``cfg.inference.weight_dtype``.

        Raises:
            RuntimeError: if no checkpoint is given and ``load_engines`` returns no engines.
        """
        self.loading = True

        self.input_sample_rate = 24000
        self.output_sample_rate = 24000

        if config:
            cfg.load_yaml( config )
            cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing

        # best-effort: a failing cfg.format() is deliberately ignored
        try:
            cfg.format()
        except Exception:
            pass

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None or dtype == "auto":
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        # push the resolved settings back into the global config
        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        # torch dtype resolved by cfg (compared with torch.int8 and used for autocast below)
        self.dtype = cfg.inference.dtype
        self.amp = amp

        self.symmap = None

        if model_ckpt:
            # Load onto the CPU so checkpoints saved from a GPU also load on CPU-only
            # hosts; the model is moved to self.device further down.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            state = torch.load(model_ckpt, map_location="cpu")
            self.model = get_models(cfg.model.get(), training=False)[0]

            # the phoneme symmap may be stored alongside the weights
            if "userdata" in state and 'symmap' in state['userdata']:
                self.symmap = state['userdata']['symmap']
            elif "symmap" in state:
                self.symmap = state['symmap']

            # unwrap checkpoints that nest the weights under "module"
            if "module" in state:
                state = state['module']

            self.model.load_state_dict(state)

            if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
                # use the same resolved dtype as the .to() cast below so both agree
                self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not amp else torch.float32).module
        else:
            engines = load_engines(training=False)
            # only the first engine's module is used
            first = next(iter(engines.items()), None)
            if first is None:
                raise RuntimeError("load_engines(training=False) returned no engines")
            self.model = first[1].module

        # int8 weights are left untouched; otherwise cast, keeping fp32 when amp is on
        # because autocast then handles the reduced precision
        if self.dtype != torch.int8:
            self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

        self.model.eval()

        # fall back to get_phone_symmap() when the checkpoint did not carry a symmap
        if self.symmap is None:
            self.symmap = get_phone_symmap()

        self.loading = False

    def encode_text( self, text, language="en" ):
        """Convert text into a 1-D tensor of phoneme symbol ids.

        A Tensor is returned unchanged (treated as already encoded). Otherwise the text
        is phonemized with ``g2p.encode``, cleaned with ``_cleanup_phones`` and mapped
        through ``self.symmap`` with start/end markers added.
        """
        # already a tensor, return it
        if isinstance( text, Tensor ):
            return text

        content = g2p.encode(text, language=language)
        content = _cleanup_phones( content )
        return self._phones_to_ids( content )

    def _phones_to_ids( self, content ):
        """Map a phoneme sequence to symbol ids via ``self.symmap``, wrapped in start/end ids."""
        # empty/falsy phonemes are encoded as a space
        phones = [ p if p else " " for p in content ]

        # Prefer the symmap's own "<s>"/"</s>" entries; symmaps lacking them fall back
        # to the hard-coded ids 1 and 2.
        if "<s>" in self.symmap and "</s>" in self.symmap:
            bos, eos = self.symmap["<s>"], self.symmap["</s>"]
        else:
            bos, eos = 1, 2

        # unknown phonemes map to None, which makes torch.tensor raise
        return torch.tensor([ bos, *map(self.symmap.get, phones), eos ])

    def encode_lang( self, language ):
        """Return a 1-element tensor holding ``language``'s id from ``get_lang_symmap()`` (0 if absent)."""
        symmap = get_lang_symmap()
        lang_id = symmap[language] if language in symmap else 0
        return torch.tensor([ lang_id ])

    def encode_audio( self, paths, trim_length=0.0 ):
        """Encode reference audio file(s) into one int16 code tensor.

        Args:
            paths: a Tensor (returned unchanged), a single ``Path``, a ``;``-separated
                string of file paths, or an iterable of paths. Each file is encoded with
                ``qnt.encode_from_file``; the results are concatenated along dim 0.
            trim_length: when non-zero, the result is trimmed with
                ``trim(res, int(75 * trim_length))`` (75 presumably being code frames per
                second - confirm against the codec).

        Raises:
            ValueError: if no paths are given (e.g. an empty string or only ``;``).
        """
        # already a tensor, return it
        if isinstance( paths, Tensor ):
            return paths

        if isinstance( paths, Path ):
            paths = [ paths ]
        elif isinstance( paths, str ):
            # split string into paths, skipping empty segments such as a trailing ";"
            paths = [ Path(p) for p in paths.split(";") if p ]

        paths = list( paths )
        if not paths:
            raise ValueError("no reference audio paths given")

        # merge inputs
        res = torch.cat([ qnt.encode_from_file(path)[0].t().to(torch.int16) for path in paths ])

        if trim_length:
            res = trim( res, int( 75 * trim_length ) )

        return res

    @staticmethod
    def _split_lines( text ):
        """Split ``text`` on newlines, dropping blank lines; return ``[text]`` if none remain."""
        lines = [ line for line in text.split("\n") if line.strip() ]
        return lines or [ text ]

    @torch.inference_mode()
    def inference(
        self,
        text,
        references,
        language="en",
        max_ar_steps=6 * 75,
        max_ar_context=-1,
        max_nar_levels=7,
        input_prompt_length=0.0,
        ar_temp=0.95,
        nar_temp=0.5,
        min_ar_temp=0.95,
        min_nar_temp=0.5,
        top_p=1.0,
        top_k=0,
        repetition_penalty=1.0,
        repetition_penalty_decay=0.0,
        length_penalty=0.0,
        beam_width=0,
        mirostat_tau=0,
        mirostat_eta=0.1,
        out_path=None
    ):
        """Synthesize ``text`` (one generation per non-blank line) conditioned on ``references``.

        Each line is run through two model calls (the first with the ``*_ar_*`` settings,
        the second with the ``*_nar_*`` settings) and decoded with ``qnt.decode_to_file``.
        The per-line waveforms are concatenated along the last dim.

        Args:
            text: input text; split on ``"\\n"``, blank lines are skipped.
            references: reference audio accepted by ``encode_audio``.
            out_path: output file; defaults to ``./data/{cfg.start_time}.wav``.
            (remaining arguments are forwarded to the model as sampling settings)

        Returns:
            ``(wav, sr)`` - the concatenated waveform and the sample rate from the decoder.
        """
        lines = self._split_lines( text )

        if out_path is None:
            out_path = f"./data/{cfg.start_time}.wav"

        # the prompt and language id are identical for every line, so encode them once
        prom = self.encode_audio( references, trim_length=input_prompt_length )
        lang = self.encode_lang( language )
        prom = to_device(prom, self.device).to(torch.int16)
        lang = to_device(lang, self.device).to(torch.uint8)

        phn_dtype = torch.uint8 if len(self.symmap) < 256 else torch.int16

        wavs = []
        sr = None

        for line in lines:
            phns = self.encode_text( line, language=language )
            phns = to_device(phns, self.device).to(phn_dtype)

            with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
                resps_list = self.model(
                    text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
                    sampling_temperature=ar_temp,
                    sampling_min_temperature=min_ar_temp,
                    sampling_top_p=top_p, sampling_top_k=top_k,
                    sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
                    sampling_length_penalty=length_penalty,
                    sampling_beam_width=beam_width,
                    sampling_mirostat_tau=mirostat_tau,
                    sampling_mirostat_eta=mirostat_eta,
                )
                resps_list = [r.unsqueeze(-1) for r in resps_list]
                resps_list = self.model(
                    text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
                    max_levels=max_nar_levels,
                    sampling_temperature=nar_temp,
                    sampling_min_temperature=min_nar_temp,
                    sampling_top_p=top_p, sampling_top_k=top_k,
                    sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
                )

            wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
            wavs.append(wav)

        wav = torch.concat(wavs, dim=-1)

        if len(wavs) > 1:
            # Every decode_to_file() call above was given the same out_path, so the file
            # would otherwise hold only the last line; rewrite it with the full audio so
            # it matches the returned waveform.
            # NOTE(review): assumes time is the last axis and leading axes are channels
            # (mono for batch size 1) - confirm against qnt.decode_to_file.
            samples = wav.detach().float().cpu().reshape(-1, wav.shape[-1]).t().numpy()
            soundfile.write(str(out_path), samples, sr)

        return (wav, sr)
|
2023-08-02 21:53:35 +00:00
|
|
|
|