import torch
import torchaudio
import soundfile

from torch import Tensor
from einops import rearrange
from pathlib import Path

from .emb import g2p, qnt
from .emb.qnt import trim, trim_random
from .utils import to_device

from .config import cfg
from .models import get_models
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize

if deepspeed_available:
    import deepspeed


class TTS():
    """Inference-only wrapper around the engines returned by `load_engines`.

    Construction configures the global `cfg` for inferencing, loads the
    engines, moves them to the requested device/dtype and switches them to
    eval mode. The `encode_*` helpers turn raw inputs (text, language codes,
    reference audio files) into tensors, and `inference` synthesizes audio.
    """

    def __init__( self, config=None, device=None, amp=None, dtype=None ):
        """Load the config (optional) and the inference engines.

        Args:
            config: path to a YAML config passed to `cfg.load_yaml`; when
                falsy the already-loaded global `cfg` is used as-is.
            device: target device; defaults to `cfg.device`.
            amp: whether to run with AMP; defaults to `cfg.inference.amp`.
            dtype: weight dtype; `None` or `"auto"` falls back to
                `cfg.inference.weight_dtype`.

        Raises:
            Any exception raised while formatting the config is printed and
            re-raised.
        """
        self.loading = True

        self.input_sample_rate = 24000
        self.output_sample_rate = 24000

        if config:
            cfg.load_yaml( config )

        try:
            cfg.format( training=False )
            # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
            cfg.dataset.use_hdf5 = False
        except Exception:
            # fail loudly on purpose: silent config errors make later failures hard to trace
            print("Error while parsing config YAML:")
            raise

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None or dtype == "auto":
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        # push the resolved settings back into the global config so load_engines sees them
        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        # read back from cfg after weight_dtype was set above
        self.dtype = cfg.inference.dtype
        self.amp = amp

        self.engines = load_engines(training=False)
        for _, engine in self.engines.items():
            # int8 weights are not cast/moved here; under AMP keep weights in float32
            if self.dtype != torch.int8:
                engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

        self.engines.eval()

        self.symmap = get_phone_symmap()

        self.loading = False

    def encode_text( self, text, language="en" ):
        """Convert text into a 1-D tensor of phoneme token ids.

        A `Tensor` input is assumed to be already encoded and is returned
        unchanged. Otherwise the text is phonemized with `g2p.encode` for
        `language` and converted with `tokenize`.
        """
        # already a tensor, return it
        if isinstance( text, Tensor ):
            return text

        content = g2p.encode(text, language=language)
        tokens = tokenize( content )

        return torch.tensor( tokens )

    def encode_lang( self, language ):
        """Return a one-element tensor holding the id of `language`.

        Languages missing from `get_lang_symmap()` map to id 0.
        """
        symmap = get_lang_symmap()
        # avoid shadowing the builtin `id`
        lang_id = symmap[language] if language in symmap else 0
        return torch.tensor([ lang_id ])

    def encode_audio( self, paths, trim_length=0.0 ):
        """Encode one or more reference audio files and concatenate the codes.

        Args:
            paths: a `Tensor` (returned unchanged, assumed already encoded),
                a single `Path`, a string of `;`-separated file paths, or an
                iterable of paths.
            trim_length: when non-zero, the result is trimmed with `trim` to
                `int(cfg.dataset.frames_per_second * trim_length)` frames
                (i.e. `trim_length` is presumably in seconds).

        Returns:
            The per-file codes concatenated along dim 0 with `torch.cat`.

        Raises:
            ValueError: if no paths were supplied.
        """
        # already a tensor, return it
        if isinstance( paths, Tensor ):
            return paths

        if isinstance( paths, Path ):
            # a lone Path is not iterable; wrap it instead of crashing
            paths = [ paths ]
        elif isinstance( paths, str ):
            # split string into paths, skipping empty segments (e.g. a trailing ";")
            paths = [ Path(p) for p in paths.split(";") if p ]

        # merge inputs
        proms = []
        for path in paths:
            prom = qnt.encode_from_file(path)
            # some encoders return an object wrapping the codes
            if hasattr( prom, "codes" ):
                prom = prom.codes
            proms.append( prom )

        if not proms:
            # torch.cat on an empty list only gives an opaque error; say what is wrong
            raise ValueError("encode_audio: no reference audio paths were provided")

        res = torch.cat(proms)

        if trim_length:
            res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )

        return res

    @torch.inference_mode()
    def inference(
        self,
        text,
        references,
        language="en",
        max_ar_steps=None,
        max_ar_context=-1,
        input_prompt_length=0.0,
        ar_temp=0.95,
        diffusion_temp=0.5,
        min_ar_temp=0.95,
        min_diffusion_temp=0.5,
        top_p=1.0,
        top_k=0,
        repetition_penalty=1.0,
        repetition_penalty_decay=0.0,
        length_penalty=0.0,
        beam_width=0,
        mirostat_tau=0,
        mirostat_eta=0.1,
        out_path=None
    ):
        """Synthesize `text` (one utterance per newline-separated line).

        Each line's audio is decoded with `qnt.decode_to_file` to `out_path`
        (defaulting to `./data/{cfg.start_time}.wav`). The per-line waveforms
        are concatenated along the last dimension.

        `max_ar_steps` defaults to `6 * cfg.dataset.frames_per_second`. This
        default is computed at call time, so a config loaded after import is
        respected.

        Returns:
            `(waveform, sample_rate)`, where `sample_rate` is the one returned
            by the last `decode_to_file` call.
        """
        if max_ar_steps is None:
            # resolved here rather than in the signature: a signature default
            # would be frozen at import time, before any YAML config is loaded
            max_ar_steps = 6 * cfg.dataset.frames_per_second

        lines = text.split("\n")

        wavs = []
        sr = None

        for name, engine in self.engines.items():
            ...

        for line in lines:
            if out_path is None:
                out_path = f"./data/{cfg.start_time}.wav"

            # NOTE(review): out_path is reused for every line, so each line
            # overwrites the same file and only the last line survives on disk,
            # while the returned waveform contains all lines. Confirm intent.
            ...

            # presumably resps_list is produced by the elided generation step above
            wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
            wavs.append(wav)

        return (torch.concat(wavs, dim=-1), sr)