vall-e/vall_e/inference.py

342 lines
10 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
import torch
import torchaudio
import soundfile
import time
import logging
import numpy as np
_logger = logging.getLogger(__name__)
2023-08-02 21:53:35 +00:00
2023-08-21 02:36:02 +00:00
from torch import Tensor
2023-08-02 21:53:35 +00:00
from einops import rearrange
2023-08-21 02:36:02 +00:00
from pathlib import Path
2023-08-02 21:53:35 +00:00
from .emb import g2p, qnt
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
from .utils import to_device, set_seed, clamp, wrapper as ml
2023-08-02 21:53:35 +00:00
from .config import cfg, Config
from .models import get_models
from .models.lora import enable_lora
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize, sentence_split
from .models import download_model, DEFAULT_MODEL_PATH
2023-08-02 21:53:35 +00:00
if deepspeed_available:
import deepspeed
2023-08-02 21:53:35 +00:00
class TTS():
    """Inference front-end for the VALL-E models.

    Constructing an instance loads a config (``.yaml``) or model file (``.sft``)
    into the global ``cfg`` and then loads the engines onto ``self.device``.
    Use :meth:`inference` for text-to-speech (``task="tts"``) or speech-to-text
    (``task="stt"``).
    """

    def __init__(self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None):
        # True only while the config and model are being loaded
        self.loading = True

        # yes I can just grab **kwargs and forward them here
        self.load_config(config=config, lora=lora, device=device, amp=amp, dtype=dtype, attention=attention)
        self.load_model()

        self.loading = False

    def load_config(self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None):
        """Load the global config and resolve the inference device, dtype and AMP flag.

        Args:
            config: path (``str`` or ``Path``) to a ``.yaml`` config or a ``.sft``
                model file. When falsy, the default model is downloaded and
                ``DEFAULT_MODEL_PATH`` is used.
            lora: forwarded to ``cfg.load_model`` when loading a ``.sft`` file.
            device: target device; ``None`` uses ``cfg.device``.
            amp: enable autocast; ``None`` uses ``cfg.inference.amp``.
            dtype: weight dtype; ``None`` or ``"auto"`` uses ``cfg.inference.weight_dtype``.
            attention: optional attention implementation, forwarded to ``load_engines``.

        Raises:
            Exception: if ``config`` has neither a ``.yaml`` nor a ``.sft`` suffix.
        """
        if not config:
            download_model()
            config = DEFAULT_MODEL_PATH

        # accept plain strings too: `.suffix` below only exists on Path objects
        config = Path(config)

        if config.suffix == ".yaml":
            _logger.info(f"Loading YAML: {config}")
            cfg.load_yaml(config)
        elif config.suffix == ".sft":
            _logger.info(f"Loading model: {config}")
            cfg.load_model(config, lora)
        else:
            raise Exception(f"Unknown config passed: {config}")

        cfg.format(training=False)
        cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None or dtype == "auto":
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        # write the resolved settings back into the global config
        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        # read back after weight_dtype is set above (presumably derived from it — confirm)
        self.dtype = cfg.inference.dtype
        self.amp = amp

        self.model_kwargs = {}
        if attention:
            self.model_kwargs["attention"] = attention

    def load_model(self):
        """(Re)load the engines described by the global config and put them in eval mode.

        Clears ``load_engines``' cache and calls ``unload_model()`` first so a
        fresh set of engines is built, then loads the phoneme symbol map into
        ``self.symmap``.
        """
        load_engines.cache_clear()
        unload_model()

        self.engines = load_engines(training=False, **self.model_kwargs)
        for _, engine in self.engines.items():
            # int8 engines are left exactly as loaded (no device/dtype move)
            if self.dtype != torch.int8:
                # with AMP the weights stay float32; autocast in `inference` applies self.dtype
                engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

        self.engines.eval()
        self.symmap = get_phone_symmap()
        _logger.info("Loaded model")

    def enable_lora(self, enabled=True):
        """Toggle the LoRA weights on every loaded engine.

        The call below resolves to the module-level ``enable_lora`` imported from
        ``.models.lora``; a method's own name is not in scope inside its body.
        """
        for _, engine in self.engines.items():
            enable_lora(engine.module, mode=enabled)

    def disable_lora(self):
        """Turn the LoRA weights off on every loaded engine."""
        return self.enable_lora(enabled=False)

    def encode_text(self, text, language="en"):
        """Convert ``text`` to a tensor of phoneme token ids via ``g2p`` + ``tokenize``.

        A ``Tensor`` is treated as already encoded and returned unchanged.
        """
        # already a tensor, return it
        if isinstance(text, Tensor):
            return text

        content = g2p.encode(text, language=language)
        tokens = tokenize(content)

        return torch.tensor(tokens)

    def encode_lang(self, language):
        """Return a 1-element tensor with the id of ``language`` (0 if it is not in the language symbol map)."""
        symmap = get_lang_symmap()
        lang_id = symmap[language] if language in symmap else 0
        return torch.tensor([lang_id])

    # to-do: trim before quantizing, instead of after
    def encode_audio(self, paths, trim_length=5.0):
        """Quantize one or more reference audio files into a single code tensor.

        Args:
            paths: a ``Tensor`` (returned unchanged, assumed already quantized), a
                ``";"``-separated string of file paths, a single ``Path``, or an
                iterable of paths.
            trim_length: if truthy, the result is passed to ``repeat_extend_audio``
                with a target of ``cfg.dataset.frames_per_second * trim_length`` frames.

        Returns:
            The per-file codes concatenated along dim 0, as ``int16``.
        """
        # already a tensor, return it
        if isinstance(paths, Tensor):
            return paths

        # split string into paths; a lone Path is wrapped because it is not iterable
        if isinstance(paths, str):
            paths = [Path(p) for p in paths.split(";")]
        elif isinstance(paths, Path):
            paths = [paths]

        # merge inputs
        proms = []
        for path in paths:
            prom = qnt.encode_from_file(path)
            if hasattr(prom, "codes"):
                prom = prom.codes
            # assumes prom[0] is (levels, frames) and is transposed to (frames, levels) — TODO confirm
            prom = prom[0].t().to(torch.int16)
            proms.append(prom)

        res = torch.cat(proms)

        if trim_length:
            res = repeat_extend_audio(res, int(cfg.dataset.frames_per_second * trim_length))
            #res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )

        return res

    def _first_module(self):
        """Return the module of the first loaded engine, or ``None`` if there are no engines."""
        for _, engine in self.engines.items():
            return engine.module
        return None

    def _select_models(self):
        """Return ``(model_ar, model_len, model_nar)``.

        For each capability (``"ar"``, ``"len"``, ``"nar"``), this picks the module of the
        first engine whose ``hyper_config.capabilities`` lists it, or ``None``.
        """
        model_ar = None
        model_len = None
        model_nar = None
        for _, engine in self.engines.items():
            capabilities = engine.hyper_config.capabilities
            if model_ar is None and "ar" in capabilities:
                model_ar = engine.module
            if model_len is None and "len" in capabilities:
                model_len = engine.module
            if model_nar is None and "nar" in capabilities:
                model_nar = engine.module
        return model_ar, model_len, model_nar

    @staticmethod
    def _pop_nar_len_kwargs(sampling_kwargs):
        """Remove the NAR-len-only options from ``sampling_kwargs``; return ``duration_padding`` (default 1.05)."""
        duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
        # popped (and currently unused) so it is not forwarded to the models
        sampling_kwargs.pop("nar_len_prefix_length", 0)
        return duration_padding

    @staticmethod
    def _save_wav(wav, sr, path):
        """Write ``wav`` to ``path`` at sample rate ``sr``.

        The last dimension is treated as time, and every leading dimension is folded
        into channels. As a result, ``(T,)``, ``(C, T)`` and ``(1, C, T)`` layouts are all accepted.
        """
        audio = wav.detach().cpu().float()
        audio = audio.reshape(-1, audio.shape[-1])  # (channels, frames)
        # soundfile expects (frames, channels)
        soundfile.write(str(path), audio.t().contiguous().numpy(), sr)

    @torch.inference_mode()
    def text_embedding(self, input, prom=False):
        """Embed text with the first loaded model's ``text_emb``.

        Args:
            input: a string (tokenized with ``cfg.tokenizer.encode``), a list of
                token ids, or a tensor of token ids (used as-is).
            prom: unused; kept for signature parity with :meth:`audio_embedding`.
        """
        model = self._first_module()

        if isinstance(input, str):
            input = cfg.tokenizer.encode(input)

        if isinstance(input, list):
            # uint8 cannot hold ids >= 256; widen only then, so small ids keep the original dtype
            dtype = torch.uint8 if max(input, default=0) < 256 else torch.int16
            input = torch.tensor(input, dtype=dtype, device=self.device)

        return model.text_emb(input)

    @torch.inference_mode()
    def audio_embedding(self, input, prom=False):
        """Embed audio codes with the first loaded model.

        With ``prom=True`` the codes go through ``proms_emb`` with ``sums=True`` and
        ``quant_level = input.shape[-1] - 1``. Otherwise, the ``resps_emb`` outputs for
        levels ``0 .. input.shape[-1] - 2`` are summed.

        NOTE(review): the response branch stops one level short of the last index.
        It also returns the int ``0`` when ``input.shape[-1] == 1``. Confirm whether the
        top level is meant to be excluded. This assumes ``input`` is (frames, levels).
        """
        model = self._first_module()

        # im really not sure which way is the better way, since the proms_emb and resps_emb have different properties.......
        if prom:
            return model.proms_emb(
                input,
                quant_level=input.shape[-1] - 1,
                offset=0,
                sums=True,
            )

        return sum([model.resps_emb(
            input[:, :level + 1],
            offset=0 if level == 0 else 1, # or maybe set to 1
            quant_level=level,
            sums=False,
        ) for level in range(input.shape[-1] - 1)])

    def modality(self, modality):
        """Resolve ``"auto"`` to the model's name when that is ``"ar+nar"`` or ``"nar-len"``; otherwise return ``modality`` unchanged."""
        # cringe to handle the best default mode for a given model
        if modality == "auto" and cfg.model.name in ["ar+nar", "nar-len"]:
            modality = cfg.model.name
        return modality

    @torch.inference_mode()
    def inference(
        self,
        text,
        references,
        text_language=None,
        language="en",
        task="tts",
        modality="auto",
        input_prompt_length=0,
        seed=None,
        out_path=None,
        tqdm=True,
        use_lora=None,
        **sampling_kwargs,
    ):
        """Synthesize speech from ``text`` (``task="tts"``) or transcribe ``references`` (``task="stt"``).

        Args:
            text: input text. It is split with ``sentence_split`` (``split_text_by``
                sampling kwarg, default ``"sentences"``) and generated line by line.
            references: prompt audio (anything :meth:`encode_audio` accepts); falsy
                means no prompt. For STT, this is the audio to transcribe.
            text_language: language used to phonemize ``text``; defaults to ``language``.
            language: language id passed to the models.
            task: ``"tts"`` or ``"stt"``.
            modality: ``"auto"``, ``"ar+nar"`` or ``"nar-len"`` (see :meth:`modality`).
            input_prompt_length: passed to :meth:`encode_audio` as ``trim_length``
                (seconds; 0 disables).
            seed: forwarded to ``set_seed``.
            out_path: WAV output path; defaults to ``./data/results/<time>.wav``.
            tqdm: show progress bars.
            use_lora: forwarded to the models.
            **sampling_kwargs: forwarded to the models. ``context_history`` (int)
                feeds the last N generated lines back as ``prefix_context``. The
                NAR-len path consumes ``duration_padding`` (default 1.05) and
                ``nar_len_prefix_length``.

        Returns:
            For STT, the decoded transcription string. For TTS, ``(wav, sr)`` with
            all lines concatenated along the last dimension; the same audio ends up
            in ``out_path``.

        Raises:
            Exception: if no suitable model is loaded for the requested task.
        """
        if not text_language:
            text_language = language

        lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))

        wavs = []
        sr = None

        model_ar, model_len, model_nar = self._select_models()

        set_seed(seed)

        modality = self.modality(modality)
        # force AR+NAR
        if modality == "ar+nar":
            model_len = None
        # force NAR-len
        elif modality == "nar-len":
            model_ar = None

        # NOTE(review): autocast's device type is fixed to "cuda" regardless of self.device.
        if task == "stt":
            # NOTE(review): encode_audio's default trim_length=5.0 repeat-extends the input
            # to ~5 s before transcription — confirm this is intended for STT.
            resp = self.encode_audio(references)
            lang = self.encode_lang(language)

            resp = to_device(resp, device=self.device, dtype=torch.int16)
            lang = to_device(lang, device=self.device, dtype=torch.uint8)

            with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
                model = model_ar if model_ar is not None else model_nar
                if model is None:
                    raise Exception("No AR or NAR model is loaded; cannot run speech-to-text.")

                text_list = model(
                    text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
                    disable_tqdm=not tqdm,
                    use_lora=use_lora,
                    **sampling_kwargs,
                )

            # NOTE(review): as written, the first replace leaves no spaces for the second,
            # so the chain only maps pre-existing "_" to " "; the first pattern was likely
            # a multi-space string lost in extraction — verify against the repository.
            text_list = [cfg.tokenizer.decode(ids).replace(" ", "_").replace(" ", "").replace("_", " ") for ids in text_list]

            return text_list[0]

        # NAR-len-only options are read once so they apply to every line
        # (popping them inside the loop only honored them for the first line)
        duration_padding = self._pop_nar_len_kwargs(sampling_kwargs) if model_len is not None else 1.05

        # stuff for rolling context
        prefix_context = None
        prefix_contexts = []
        context_history = sampling_kwargs.get("context_history", 0)

        if out_path is None:
            output_dir = Path("./data/results/")
            output_dir.mkdir(parents=True, exist_ok=True)
            out_path = output_dir / f"{time.time()}.wav"

        # the prompt and language are the same for every line, so encode them once.
        # A Tensor reference is checked by type because bool() on a multi-element tensor raises.
        has_references = isinstance(references, Tensor) or bool(references)
        prom = self.encode_audio(references, trim_length=input_prompt_length) if has_references else None
        lang = self.encode_lang(language)
        prom = to_device(prom, device=self.device, dtype=torch.int16)
        lang = to_device(lang, device=self.device, dtype=torch.uint8)

        for line in lines:
            phns = self.encode_text(line, language=text_language)
            phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)

            # to-do: add in case for experimental.hf model
            with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
                kwargs = {} if prefix_context is None else {"prefix_context": prefix_context}

                if model_len is not None:
                    len_list = model_len(text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, max_duration=5) # "max_duration" is max tokens
                    # pad the predicted durations by the duration_padding factor
                    len_list = [int(length * duration_padding) for length in len_list]

                    resps_list = model_nar(
                        text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
                        disable_tqdm=not tqdm,
                        use_lora=use_lora,
                        **(sampling_kwargs | kwargs),
                    )
                elif model_ar is not None:
                    resps_list = model_ar(
                        text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
                        disable_tqdm=not tqdm,
                        use_lora=use_lora,
                        **(sampling_kwargs | kwargs),
                    )
                    resps_list = model_nar(
                        text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
                        disable_tqdm=not tqdm,
                        use_lora=use_lora,
                        **sampling_kwargs,
                    )
                else:
                    raise Exception("No AR or NAR-len model is loaded; cannot run text-to-speech.")

            # to-do: care about batching later
            resps = resps_list[0]

            # store current context to use as the initial input for later lines
            if context_history > 0:
                prefix_contexts.append((phns, resps, resps.shape[0]))
                # only the last `context_history` entries are ever used
                prefix_contexts = prefix_contexts[-context_history:]
                prefix_context = (
                    [torch.concat([x[0] for x in prefix_contexts])],
                    [torch.concat([x[1] for x in prefix_contexts])],
                    [sum(x[2] for x in prefix_contexts)],
                )

            # decode this line (decode_to_file also writes it to out_path)
            wav, sr = qnt.decode_to_file(resps, out_path, device=self.device)
            # add utterances
            wavs.append(wav)

        # combine all utterances
        wav = torch.concat(wavs, dim=-1)

        # every line was written to the same out_path, so only the last one survived on disk;
        # rewrite the file with the full concatenated audio
        if len(wavs) > 1:
            self._save_wav(wav, sr, out_path)

        return (wav, sr)
2023-08-02 21:53:35 +00:00