2023-08-02 21:53:35 +00:00
|
|
|
import torch
|
|
|
|
import torchaudio
|
|
|
|
import soundfile
|
|
|
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
|
|
from .emb import g2p, qnt
|
|
|
|
from .utils import to_device
|
|
|
|
|
|
|
|
from .config import cfg
|
2023-08-14 03:07:45 +00:00
|
|
|
from .models import get_models
|
2023-08-16 02:58:16 +00:00
|
|
|
from .train import load_engines
|
2023-08-14 03:07:45 +00:00
|
|
|
from .data import get_phone_symmap
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2023-08-16 02:58:16 +00:00
|
|
|
import random
|
|
|
|
|
|
|
|
def trim( qnt, trim_length ):
    """Return a random contiguous window of `trim_length` frames from `qnt`.

    The window's start offset is drawn uniformly from every valid offset
    (0 .. length - trim_length inclusive), so each window is equally likely.
    The previous implementation drew the start from [0, length) and clamped
    overshoots to the final window, heavily biasing crops toward the tail.

    Args:
        qnt: anything with `.shape[0]` that supports slicing along its first
            dimension (e.g. a tensor of quantized audio codes).
        trim_length: number of leading-dimension entries to keep.

    Returns:
        `qnt[start:start + trim_length]` for a random valid `start`, or `qnt`
        unchanged when it is not longer than `trim_length`.
    """
    length = qnt.shape[0]
    if length <= trim_length:
        # Nothing to crop; also avoids computing a negative start offset.
        return qnt
    # randint is inclusive on both ends, so the last full window is reachable.
    start = random.randint(0, length - trim_length)
    return qnt[start:start + trim_length]
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
class TTS():
    """Inference wrapper that loads the AR and NAR models and synthesizes speech.

    Models are loaded either from explicit state-dict checkpoints (when both
    `ar_ckpt` and `nar_ckpt` are given) or via the trainer's `load_engines()`.
    """

    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
        """Load configuration and models, then put both models in eval mode.

        Args:
            config: optional path passed to `cfg.load_yaml` before `cfg.format()`.
            ar_ckpt: path to an AR state-dict checkpoint.
            nar_ckpt: path to a NAR state-dict checkpoint. Checkpoints are used
                only when BOTH paths are given; otherwise `load_models()` is used.
            device: torch device the models are moved to (default "cuda").
        """
        # True while the constructor is still loading models.
        self.loading = True
        self.device = device

        self.input_sample_rate = 24000
        self.output_sample_rate = 24000

        if config:
            cfg.load_yaml( config )

        cfg.format()

        # Checkpoint paths are not derived from cfg.ckpt_dir automatically;
        # callers must pass both explicitly to bypass the training engines.
        if ar_ckpt and nar_ckpt:
            self.ar_ckpt = ar_ckpt
            self.nar_ckpt = nar_ckpt

            models = get_models(cfg.models.get())
            for name, model in models.items():
                if name.startswith("ar"):
                    self.ar = self._load_state_dict( model, self.ar_ckpt )
                elif name.startswith("nar"):
                    self.nar = self._load_state_dict( model, self.nar_ckpt )
        else:
            self.load_models()

        self.symmap = get_phone_symmap()
        self.ar.eval()
        self.nar.eval()

        self.loading = False

    def _load_state_dict( self, model, path ):
        """Move `model` to `self.device` as float32, load weights from `path`, and return it."""
        model = model.to(self.device, dtype=torch.float32)
        # Deserialize onto the CPU so a checkpoint saved on another device
        # (e.g. a GPU) still loads, and so no second copy is staged on the GPU;
        # load_state_dict then copies the weights into the device-resident model.
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        state = torch.load(path, map_location="cpu")
        # Some checkpoints (presumably engine-wrapped ones) nest the weights under "module".
        if "module" in state:
            state = state['module']
        model.load_state_dict(state)
        return model

    def load_models( self ):
        """Populate `self.ar` / `self.nar` from the engines returned by `load_engines()`."""
        engines = load_engines()
        for name, engine in engines.items():
            if name.startswith("ar"):
                self.ar = engine.module.to(self.device)
            elif name.startswith("nar"):
                self.nar = engine.module.to(self.device)

    def encode_text( self, text, lang_marker="en" ):
        """Phonemize `text` and map the phonemes to symbol ids.

        The phoneme sequence is wrapped in `<lang_marker>` / `</lang_marker>`
        tokens, and falsy phonemes are replaced by a space. Phonemes missing
        from `self.symmap` are silently dropped.

        Returns:
            A 1-D integer tensor of symbol ids.
        """
        phonemes = g2p.encode(text)
        phones = [f"<{lang_marker}>"] + [ " " if not p else p for p in phonemes ] + [f"</{lang_marker}>"]
        mapped = [self.symmap[p] for p in phones if p in self.symmap]
        # Explicit dtype keeps an empty result integral (torch.tensor([]) would be float).
        return torch.tensor( mapped, dtype=torch.long )

    def encode_audio( self, path, should_trim=True ):
        """Quantize the audio file at `path` into an int16 code tensor.

        Args:
            path: audio file passed to `qnt.encode_from_file`.
            should_trim: when true (the default, matching previous behavior),
                randomly crop the codes to `75 * cfg.dataset.duration_range[1]`
                frames. (The previous `if trim:` tested the module-level
                `trim` function itself, so it was always true.)

        Returns:
            `enc[0]` transposed and cast to int16. NOTE(review): presumably
            (frames, levels) — confirm against `qnt.encode_from_file`.
        """
        enc = qnt.encode_from_file( path )
        res = enc[0].t().to(torch.int16)
        if should_trim:
            # NOTE(review): 75 looks like code frames per second of audio; confirm
            # against the codec's frame rate.
            res = trim( res, int( 75 * cfg.dataset.duration_range[1] ) )
        return res

    def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
        """Synthesize `text` conditioned on the `reference` audio and write the result to `out_path`.

        Args:
            text: input text, encoded with `encode_text`.
            reference: path to the prompt audio, encoded with `encode_audio`.
            mode: currently unused; kept for interface compatibility.
            max_ar_steps: maximum number of AR decoding steps.
            ar_temp: sampling temperature for the AR pass.
            nar_temp: sampling temperature for the NAR pass.
            out_path: destination file for `qnt.decode_to_file`.

        Returns:
            The `(wav, sr)` pair returned by `qnt.decode_to_file`.
        """
        prom = self.encode_audio( reference )
        phns = self.encode_text( text )

        prom = to_device(prom, self.device).to(torch.int16)
        # Use the narrowest integer dtype that can hold every symbol id.
        phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

        # Pure inference: disable autograd so no computation graph is built or
        # kept alive across the AR/NAR passes and decoding.
        with torch.no_grad():
            resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
            # Append a trailing axis (presumably the codebook-level axis) before the NAR pass.
            resps_list = [r.unsqueeze(-1) for r in resps_list]
            resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)

            wav, sr = qnt.decode_to_file(resps_list[0], out_path)

        return (wav, sr)
|
|
|
|
|