2023-08-02 21:53:35 +00:00
|
|
|
import torch
|
|
|
|
import torchaudio
|
|
|
|
import soundfile
|
2024-06-25 18:41:29 +00:00
|
|
|
import time
|
2024-08-29 18:27:16 +00:00
|
|
|
import logging
|
2024-11-10 04:57:34 +00:00
|
|
|
import numpy as np
|
2024-08-29 18:27:16 +00:00
|
|
|
|
|
|
|
# module-level logger; TTS.load_config / TTS.load_model report loading progress through it
_logger = logging.getLogger(__name__)
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2023-08-21 02:36:02 +00:00
|
|
|
from torch import Tensor
|
2023-08-02 21:53:35 +00:00
|
|
|
from einops import rearrange
|
2023-08-21 02:36:02 +00:00
|
|
|
from pathlib import Path
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
from .emb import g2p, qnt
|
2024-10-16 00:25:03 +00:00
|
|
|
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
|
2024-11-10 04:57:34 +00:00
|
|
|
from .utils import to_device, set_seed, clamp, wrapper as ml
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2024-07-16 00:59:48 +00:00
|
|
|
from .config import cfg, Config
|
2023-08-14 03:07:45 +00:00
|
|
|
from .models import get_models
|
2024-10-10 18:40:25 +00:00
|
|
|
from .models.lora import enable_lora
|
2023-10-09 20:24:04 +00:00
|
|
|
from .engines import load_engines, deepspeed_available
|
2024-12-12 02:55:43 +00:00
|
|
|
from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split
|
2024-10-26 03:15:15 +00:00
|
|
|
from .models import download_model, DEFAULT_MODEL_PATH
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2023-10-09 20:24:04 +00:00
|
|
|
# `deepspeed_available` comes from .engines; deepspeed is only imported when that flag is truthy
# (presumably so this module still imports when deepspeed is not installed — confirm in .engines)
if deepspeed_available:
    import deepspeed
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
class TTS:
    """High-level inference wrapper around the VALL-E engines.

    Construction configures the global ``cfg`` for inferencing (downloading the default
    model when no config is given) and loads the engines. The instance then offers text /
    audio / language encoding helpers, serial inference (``inference``, supports TTS and STT
    plus rolling context) and batched TTS inference (``batched_inference``).
    """

    def __init__(self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None):
        """Load the configuration and the model.

        Args:
            config: ``.yaml`` config or ``.sft`` model path (``str`` or ``Path``); ``None``
                downloads and uses ``DEFAULT_MODEL_PATH``.
            lora: forwarded to ``cfg.load_model`` when loading a ``.sft`` file.
            device: target device; defaults to ``cfg.device``.
            amp: enable autocast; defaults to ``cfg.inference.amp``.
            dtype: weight dtype; ``None`` / ``"auto"`` uses ``cfg.inference.weight_dtype``.
            attention: optional attention setting forwarded to ``load_engines``.
        """
        # True while the config / model are being loaded
        self.loading = True

        self.load_config(config=config, lora=lora, device=device, amp=amp, dtype=dtype, attention=attention)
        self.load_model()

        self.loading = False

    def load_config(self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None):
        """Populate the global ``cfg`` for inferencing and cache runtime settings on ``self``.

        The config suffix selects the loader: ``.yaml`` -> ``cfg.load_yaml``, ``.sft`` ->
        ``cfg.load_model``.

        Raises:
            ValueError: if the config has any other suffix.
        """
        if not config:
            download_model()
            config = DEFAULT_MODEL_PATH

        # accept plain strings as well as Path objects (``.suffix`` only exists on Path)
        config = Path(config)

        if config.suffix == ".yaml":
            _logger.info(f"Loading YAML: {config}")
            cfg.load_yaml(config)
        elif config.suffix == ".sft":
            _logger.info(f"Loading model: {config}")
            cfg.load_model(config, lora)
        else:
            raise ValueError(f"Unknown config passed: {config}")

        cfg.format(training=False)
        cfg.dataset.use_hdf5 = False  # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None or dtype == "auto":
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        # read back after weight_dtype is set above (presumably the torch dtype derived from it)
        self.dtype = cfg.inference.dtype
        self.amp = amp
        self.batch_size = cfg.inference.batch_size

        # extra keyword arguments forwarded to load_engines()
        self.model_kwargs = {}
        if attention:
            self.model_kwargs["attention"] = attention

    def load_model(self):
        """(Re)load the engines, move them to the configured device/dtype and load the phone symmap."""
        load_engines.cache_clear()
        unload_model()

        self.engines = load_engines(training=False, **self.model_kwargs)
        for _, engine in self.engines.items():
            # int8 weights are not re-cast; under amp the weights stay float32 and autocast handles precision
            if self.dtype != torch.int8:
                engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

        self.engines.eval()
        self.symmap = get_phone_symmap()
        _logger.info("Loaded model")

    def enable_lora(self, enabled=True):
        """Toggle LoRA on every loaded engine's module."""
        for _, engine in self.engines.items():
            # NOTE: inside a method the bare name `enable_lora` resolves to the module-level
            # function imported from .models.lora, not to this method
            enable_lora(engine.module, mode=enabled)

    def disable_lora(self):
        """Disable LoRA on every loaded engine's module."""
        return self.enable_lora(enabled=False)

    def encode_text(self, text, language="auto", precheck=True):
        """Turn ``text`` into a tensor of phoneme token ids.

        A Tensor is returned unchanged. With ``precheck``, text that already tokenizes
        without any ``<unk>`` (for example already-phonemized input) is used as-is;
        otherwise it is phonemized with ``g2p.encode`` first.
        """
        # already a tensor, return it
        if isinstance(text, Tensor):
            return text

        # check if it tokenizes without any unks (for example, if already phonemized text is passed)
        if precheck and "<unk>" in self.symmap:
            tokens = tokenize(text)
            if self.symmap["<unk>"] not in tokens:
                return torch.tensor(tokens)

        content = g2p.encode(text, language=language)
        tokens = tokenize(content)
        return torch.tensor(tokens)

    def encode_lang(self, language):
        """Return a 1-element tensor holding the id of ``language`` (0 when unknown)."""
        symmap = get_lang_symmap()
        lang_id = symmap[language] if language in symmap else 0
        return torch.tensor([lang_id])

    # to-do: trim before quantizing, instead of after
    def encode_audio(self, paths, trim_length=5.0):
        """Quantize one or more audio files into a single int16 code tensor.

        Args:
            paths: a Tensor (returned unchanged), a ``Path``, a list of paths, or a
                ``;``-separated string of paths. Codes from several files are concatenated.
            trim_length: seconds to repeat/trim the result to via ``repeat_extend_audio``;
                falsy disables it.
        """
        # already a tensor, return it
        if isinstance(paths, Tensor):
            return paths

        # split string into paths; skip empty pieces so a trailing ";" does not produce Path("")
        if isinstance(paths, str):
            paths = [Path(p) for p in paths.split(";") if p]

        # not already a list
        if isinstance(paths, Path):
            paths = [paths]

        # merge inputs
        proms = []
        for path in paths:
            prom = qnt.encode_from_file(path)
            if hasattr(prom, "codes"):
                prom = prom.codes
            prom = prom[0].t().to(torch.int16)
            proms.append(prom)

        res = torch.cat(proms)

        if trim_length:
            res = repeat_extend_audio(res, int(cfg.dataset.frames_per_second * trim_length))
            # res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )

        return res

    def _first_module(self):
        """Return the module of the first loaded engine, or None when no engine is loaded."""
        for _, engine in self.engines.items():
            return engine.module
        return None

    @torch.inference_mode()
    def text_embedding(self, input, prom=False):
        """Embed text (a string, a list of token ids, or a tensor) with the first model's ``text_emb``.

        ``prom`` is accepted for signature parity with ``audio_embedding`` and is unused.
        """
        model = self._first_module()

        if isinstance(input, str):
            input = cfg.tokenizer.encode(input)

        if isinstance(input, list):
            # NOTE(review): uint8 overflows for token ids >= 256 — confirm the tokenizer vocab fits
            input = torch.tensor(input, dtype=torch.uint8, device=self.device)

        return model.text_emb(input)

    @torch.inference_mode()
    def audio_embedding(self, input, prom=False):
        """Embed audio codes with the first model, via ``proms_emb`` (``prom=True``) or summed ``resps_emb``."""
        model = self._first_module()

        # im really not sure which way is the better way, since the proms_emb and resps_emb have different properties.......
        if prom:
            return model.proms_emb(
                input,
                quant_level=input.shape[-1] - 1,
                offset=0,
                sums=True,
            )

        # NOTE(review): this sums levels 0..n-2 only (the last level is skipped), and yields the
        # int 0 for single-level input — confirm that is intended
        return sum([
            model.resps_emb(
                input[:, :l + 1],
                offset=0 if l == 0 else 1,  # or maybe set to 1
                quant_level=l,
                sums=False,
            )
            for l in range(input.shape[-1] - 1)
        ])

    def modality(self, modality):
        """Resolve ``"auto"`` to the model name when it is ``"ar+nar"`` or ``"nar-len"``; otherwise pass through."""
        # cringe to handle the best default mode for a given model
        if modality == "auto" and cfg.model.name in ["ar+nar", "nar-len"]:
            modality = cfg.model.name
        return modality

    @staticmethod
    def _has_reference(reference):
        """True when ``reference`` names/holds audio; Tensors count as present (no ambiguous tensor truthiness)."""
        if reference is None:
            return False
        if isinstance(reference, Tensor):
            return True
        return bool(reference)

    def _autocast_device_type(self):
        """Device *type* for ``torch.autocast`` (e.g. ``"cuda"`` for ``"cuda:0"``), which rejects indexed devices."""
        return torch.device(self.device).type

    def _resolve_models(self, modality="auto"):
        """Pick the first engine module per capability ("ar", "len", "nar"), then apply the modality override.

        Returns:
            ``(model_ar, model_len, model_nar)``; entries may be None.
        """
        model_ar = model_len = model_nar = None
        for _, engine in self.engines.items():
            capabilities = engine.hyper_config.capabilities
            if model_ar is None and "ar" in capabilities:
                model_ar = engine.module
            if model_len is None and "len" in capabilities:
                model_len = engine.module
            if model_nar is None and "nar" in capabilities:
                model_nar = engine.module

        modality = self.modality(modality)
        if modality == "ar+nar":
            # force AR+NAR
            model_len = None
        elif modality == "nar-len":
            # force NAR-len
            model_ar = None

        return model_ar, model_len, model_nar

    @staticmethod
    def _make_batches(entries, batch_size):
        """Sort entries by their trailing sequence length (to reduce padding) and chunk them.

        Every batch holds at most ``batch_size`` entries; an empty input yields no batches.
        """
        ordered = sorted(entries, key=lambda entry: entry[-1])
        size = max(1, int(batch_size))
        return [ordered[start:start + size] for start in range(0, len(ordered), size)]

    @staticmethod
    def _generate(model_ar, model_len, model_nar, input_kwargs, count, duration_padding, sampling_kwargs, extra_kwargs=None):
        """Run one TTS generation pass and return the response code list.

        NAR-len: predict lengths with ``model_len``, scale them by ``duration_padding``, then fill with
        ``model_nar``. AR+NAR: ``model_ar`` then ``model_nar`` refines its output. ``extra_kwargs``
        (e.g. ``prefix_context``) go to the first generating model only.

        Raises:
            Exception: when neither a length model nor an AR model is available.
        """
        extra_kwargs = extra_kwargs or {}
        if model_len is not None:
            len_list = model_len(**input_kwargs, task_list=["len"] * count, **{"max_duration": 5})  # "max_duration" is max tokens
            # add an additional X seconds
            len_list = [int(l * duration_padding) for l in len_list]
            return model_nar(
                **input_kwargs, len_list=len_list, task_list=["tts"] * count,
                **(sampling_kwargs | extra_kwargs),
            )

        if model_ar is not None:
            resps_list = model_ar(
                **input_kwargs, task_list=["tts"] * count,
                **(sampling_kwargs | extra_kwargs),
            )
            return model_nar(
                **input_kwargs, resps_list=resps_list, task_list=["tts"] * count,
                **sampling_kwargs,
            )

        raise Exception("No NAR-len or AR model is loaded for text-to-speech")

    @staticmethod
    def _write_wav(path, wav, sr):
        """Write ``wav`` to ``path`` with soundfile.

        Time is taken to be the last dimension (utterances are concatenated along ``dim=-1``);
        any leading dimensions are flattened into channels.
        """
        data = wav.detach().float().cpu().reshape(-1, wav.shape[-1])
        soundfile.write(str(path), data.t().contiguous().numpy(), sr)

    # makes use of being able to batch inputs seamlessly by automatically batching
    # this is NOT the default because it absolutely cannot make use of rolling context / prefixing
    @torch.inference_mode()
    def batched_inference(
        self,
        texts,
        references=None,
        languages=None,
        text_languages=None,
        out_paths=None,
        **sampling_kwargs,
    ):
        """Synthesize many texts at once, grouping them into length-sorted batches.

        Args:
            texts: sequence of texts (or pre-encoded phoneme tensors).
            references: optional per-sample reference audio (anything ``encode_audio`` accepts).
            languages: optional per-sample languages; ``"auto"`` entries are detected from the text.
            text_languages: optional per-sample phonemizer languages; defaults to the (detected) language.
            out_paths: optional per-sample output paths; falsy entries are decoded without writing.
            **sampling_kwargs: forwarded to the models after the control keys ``batch_size``,
                ``input_prompt_length``, ``modality``, ``seed``, ``tqdm``, ``use_lora``, ``dtype``,
                ``amp`` and ``duration_padding`` are removed.

        Returns:
            list of decoded waveforms in the same order as ``texts``. The caller's lists are not modified.
        """
        batch_size = sampling_kwargs.pop("batch_size", self.batch_size)
        input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
        modality = sampling_kwargs.pop("modality", "auto")
        seed = sampling_kwargs.pop("seed", None)
        tqdm = sampling_kwargs.pop("tqdm", True)
        use_lora = sampling_kwargs.pop("use_lora", None)
        dtype = sampling_kwargs.pop("dtype", self.dtype)
        amp = sampling_kwargs.pop("amp", self.amp)
        # popped once up front: popping it per batch meant only the first batch saw the caller's value
        duration_padding = sampling_kwargs.pop("duration_padding", 1.05)

        if batch_size < 1:
            batch_size = 1

        model_ar, model_len, model_nar = self._resolve_models(modality)

        samples = len(texts)
        # fill with null input proms
        if not references:
            references = [None] * samples
        # fill with automatic language detection
        if not languages:
            languages = ["auto"] * samples
        if not out_paths:
            out_paths = [None] * samples

        phn_dtype = torch.uint8 if len(self.symmap) < 256 else torch.int16

        # tensorfy inputs into local entries rather than overwriting the caller's lists
        entries = []
        for i in range(samples):
            # detect language
            language = languages[i]
            if language == "auto":
                language = g2p.detect_language(texts[i])
            # use the audio language to phonemize the text unless told otherwise
            text_language = text_languages[i] if text_languages else language

            phns = self.encode_text(texts[i], language=text_language)
            prom = self.encode_audio(references[i], trim_length=input_prompt_length) if self._has_reference(references[i]) else None
            lang = self.encode_lang(language)

            phns = to_device(phns, device=self.device, dtype=phn_dtype)
            prom = to_device(prom, device=self.device, dtype=torch.int16)
            lang = to_device(lang, device=self.device, dtype=torch.uint8)

            seq_len = phns.shape[0] + 1 + (prom.shape[0] if prom is not None else 0) + 1
            # keep the original index so outputs can be returned in input order after sorting
            entries.append((i, phns, prom, lang, out_paths[i], seq_len))

        autocast_device = self._autocast_device_type()
        wavs = [None] * samples
        for batch in self._make_batches(entries, batch_size):
            seed = set_seed(seed)
            input_kwargs = dict(
                text_list=[entry[1] for entry in batch],
                proms_list=[entry[2] for entry in batch],
                lang_list=[entry[3] for entry in batch],
                disable_tqdm=not tqdm,
                use_lora=use_lora,
            )

            with torch.autocast(autocast_device, dtype=dtype, enabled=amp):
                resps_list = self._generate(
                    model_ar, model_len, model_nar, input_kwargs, len(batch), duration_padding, sampling_kwargs,
                )

            for entry, resp in zip(batch, resps_list):
                index, out_path = entry[0], entry[4]
                if out_path:
                    wav, sr = qnt.decode_to_file(resp, out_path, device=self.device)
                else:
                    wav, sr = qnt.decode(resp, device=self.device)
                wavs[index] = wav

        return wavs

    # naive serial inferencing
    # will automatically split a text into pieces (if requested) piece by piece
    @torch.inference_mode()
    def inference(
        self,
        text,
        references,
        language="auto",
        text_language=None,
        task="tts",
        out_path=None,
        **sampling_kwargs,
    ):
        """Run serial inference.

        For ``task="stt"`` the audio in ``references`` is transcribed and the decoded string is
        returned. Otherwise ``text`` is split with ``sentence_split`` (``split_text_by`` kwarg,
        default ``"sentences"``), each piece is synthesized in turn (optionally conditioned on
        the last ``context_history`` pieces), and the concatenated audio is written once to
        ``out_path`` (default ``./data/results/<timestamp>.wav``).

        Returns:
            ``(wav, sample_rate)`` for TTS, or the transcription ``str`` for STT.
        """
        input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
        modality = sampling_kwargs.pop("modality", "auto")
        seed = sampling_kwargs.pop("seed", None)
        tqdm = sampling_kwargs.pop("tqdm", True)
        use_lora = sampling_kwargs.pop("use_lora", None)
        dtype = sampling_kwargs.pop("dtype", self.dtype)
        amp = sampling_kwargs.pop("amp", self.amp)
        # popped once up front: popping it per line meant only the first line saw the caller's value
        duration_padding = sampling_kwargs.pop("duration_padding", 1.05)

        model_ar, model_len, model_nar = self._resolve_models(modality)
        seed = set_seed(seed)
        autocast_device = self._autocast_device_type()

        if task == "stt":
            # transcribe the whole clip: encode_audio's default 5s repeat/trim would loop short clips and cut long ones
            resp = self.encode_audio(references, trim_length=0)
            lang = self.encode_lang(language)

            resp = to_device(resp, device=self.device, dtype=torch.int16)
            lang = to_device(lang, device=self.device, dtype=torch.uint8)

            model = model_ar if model_ar is not None else model_nar
            if model is None:
                raise Exception("No AR or NAR model is loaded for speech-to-text")

            with torch.autocast(autocast_device, dtype=dtype, enabled=amp):
                text_list = model(
                    text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
                    disable_tqdm=not tqdm,
                    use_lora=use_lora,
                    **sampling_kwargs,
                )

            # decoded tokens are space-separated; presumably a literal space shows up as a run of three
            # spaces, which is protected, the separators dropped, then restored — confirm against cfg.tokenizer
            text_list = [
                cfg.tokenizer.decode(tokens).replace("   ", "_").replace(" ", "").replace("_", " ")
                for tokens in text_list
            ]
            return text_list[0]

        lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))

        # the reference prompt is the same for every line: encode it once
        prom = self.encode_audio(references, trim_length=input_prompt_length) if self._has_reference(references) else None
        prom = to_device(prom, device=self.device, dtype=torch.int16)
        phn_dtype = torch.uint8 if len(self.symmap) < 256 else torch.int16

        # stuff for rolling context
        prefix_context = None
        prefix_contexts = []
        context_history = sampling_kwargs.get("context_history", 0)

        auto_lang = not language or language == "auto"
        auto_text_lang = not text_language or text_language == "auto"

        wavs = []
        sr = None
        for line in lines:
            deduced_language = g2p.detect_language(line) if auto_lang or auto_text_lang else language
            if auto_lang:
                language = deduced_language
            if auto_text_lang:
                text_language = deduced_language

            phns = to_device(self.encode_text(line, language=text_language), device=self.device, dtype=phn_dtype)
            lang = to_device(self.encode_lang(language), device=self.device, dtype=torch.uint8)

            input_kwargs = dict(
                text_list=[phns],
                proms_list=[prom],
                lang_list=[lang],
                disable_tqdm=not tqdm,
                use_lora=use_lora,
            )
            extra_kwargs = {} if prefix_context is None else {"prefix_context": prefix_context}

            with torch.autocast(autocast_device, dtype=dtype, enabled=amp):
                resps_list = self._generate(
                    model_ar, model_len, model_nar, input_kwargs, 1, duration_padding, sampling_kwargs, extra_kwargs,
                )

            # to-do: care about batching later
            resps = resps_list[0]

            # store current context to use as the initial input for later lines
            if context_history > 0:
                prefix_contexts.append((phns, resps, resps.shape[0]))
                recent = prefix_contexts[-context_history:]
                prefix_context = (
                    [torch.concat([x[0] for x in recent])],
                    [torch.concat([x[1] for x in recent])],
                    [sum(x[2] for x in recent)],
                )

            # decode in memory; the combined audio is written once below
            wav, sr = qnt.decode(resps, device=self.device)
            wavs.append(wav)

        # combine all utterances
        wav = torch.concat(wavs, dim=-1)

        # write the whole utterance once (writing per line left only the last line on disk)
        if out_path is None:
            output_dir = Path("./data/results/")
            output_dir.mkdir(parents=True, exist_ok=True)
            out_path = output_dir / f"{time.time()}.wav"
        self._write_wav(out_path, wav, sr)

        return (wav, sr)
|
2023-08-02 21:53:35 +00:00
|
|
|
|