diff --git a/vall_e/__main__.py b/vall_e/__main__.py index 91673a1..e1cd269 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -2,11 +2,15 @@ import argparse from pathlib import Path from .inference import TTS +def path_list(arg): + return [Path(p) for p in arg.split(";")] + def main(): parser = argparse.ArgumentParser("VALL-E TTS") parser.add_argument("text") - parser.add_argument("reference", type=Path) - parser.add_argument("out_path", type=Path) + parser.add_argument("references", type=path_list) + parser.add_argument("--out-path", type=Path, default=None) + parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--ar-ckpt", type=Path, default=None) parser.add_argument("--nar-ckpt", type=Path, default=None) @@ -17,7 +21,7 @@ def main(): args = parser.parse_args() tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device ) - tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp ) + tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp ) if __name__ == "__main__": main() diff --git a/vall_e/config.py b/vall_e/config.py index 9205814..f152ae7 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -404,9 +404,19 @@ class Trainer: @dataclass() class Inference: + weight_dtype: str = "float32" + normalize: bool = False # do NOT enable this unless you know exactly what you're doing use_vocos: bool = True + @cached_property + def dtype(self): + if self.weight_dtype == "float16": + return torch.float16 + if self.weight_dtype == "bfloat16": + return torch.bfloat16 + return torch.float32 + @dataclass() class BitsAndBytes: enabled: bool = False diff --git a/vall_e/data.py b/vall_e/data.py index ff49c79..ce7237c 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -28,8 +28,8 @@ from tqdm.auto import tqdm _logger = logging.getLogger(__name__) def get_phone_symmap(): - #if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5: - # return json.loads( cfg.hdf5['symmap'].asstr()[()] ) + if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5: + return json.loads( cfg.hdf5['symmap'].asstr()[()] ) symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} return symmap @@ -67,12 +67,9 @@ def _load_quants(path) -> Tensor: return torch.load(path)[0][:, :].t().to(torch.int16) @cache -def _get_phones(path, lang_marker="en"): - path = _get_phone_path(path) - with open(path, "r", encoding="utf8") as f: - content = f.read() - split = content.split(" ") - return [f""] + [ " " if not p else p for p in split ] + [f""] +def _get_phones(path, language="en"): + content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ") + return [""] + [ " " if not p else p for p in split ] + [""] def _interleaved_reorder(l, fn): groups = defaultdict(list) @@ -779,6 +776,7 @@ if __name__ == "__main__": continue print(text, task, cfg.models.prom_levels) + print( proms.shape, resps.shape ) decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" ) decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" ) break \ No newline at end of file diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 90164ab..7985b85 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -182,14 +182,14 @@ def encode_from_file(path, device="cuda"): # trims a random piece of audio, up to `target` def trim_random( qnt, target ): - length = qnt.shape[0] + length = max( qnt.shape[0], qnt.shape[1] ) start = int(length * random.random()) end = start + target if end >= length: start = length - target end = length - return qnt[start:end] + return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end] # repeats the audio to fit the target size def repeat_extend_audio( qnt, target ): diff --git a/vall_e/inference.py b/vall_e/inference.py index 5b18164..e481f50 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -2,7 +2,9 @@ import torch import torchaudio import soundfile +from torch import Tensor from einops import rearrange +from pathlib import Path from .emb import g2p, qnt from .emb.qnt import trim_random @@ -28,16 +30,6 @@ class TTS(): cfg.format() except Exception as e: pass - - """ - if cfg.trainer.load_state_dict: - for model in cfg.models.get(): - path = cfg.ckpt_dir / model.full_name / "fp32.pth" - if model.name.startswith("ar"): - ar_ckpt = path - if model.name.startswith("nar"): - nar_ckpt = path - """ if ar_ckpt and nar_ckpt: self.ar_ckpt = ar_ckpt @@ -74,22 +66,39 @@ class TTS(): elif name[:3] == "nar": self.nar = engine.module.to(self.device) - def encode_text( self, text, lang_marker="en" ): - content = g2p.encode(text) + def encode_text( self, text, language="en" ): + # already a tensor, return it + if isinstance( text, Tensor ): + return text + + content = g2p.encode(text, language=language) #phones = [""] + [ " " if not p else p for p in content ] + [""] phones = [ " " if not p else p for p in content ] return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ]) - def encode_audio( self, path, trim=True ): - enc = qnt.encode_from_file( path ) - res = enc[0].t().to(torch.int16) + def encode_audio( self, paths, trim=True ): + # already a tensor, return it + if isinstance( paths, Tensor ): + return paths + + # split string into paths + if isinstance( paths, str ): + paths = [ Path(p) for p in paths.split(";") ] + + # merge inputs + res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths]) + if trim: - res = trim_random( res, int( 75 * cfg.dataset.duration_range[1] ) ) + res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) ) + return res + @torch.inference_mode() + def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path=None ): + if out_path is None: + out_path = f"./data/{text}.wav" - def inference( self, text, reference, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ): - prom = self.encode_audio( reference ) + prom = self.encode_audio( references ) phns = self.encode_text( text ) prom = to_device(prom, self.device).to(torch.int16)