inferencing cleanup

This commit is contained in:
mrq 2023-08-20 21:36:02 -05:00
parent a47029065b
commit 7b1b82e0e5
5 changed files with 52 additions and 31 deletions

View File

@ -2,11 +2,15 @@ import argparse
from pathlib import Path
from .inference import TTS
def path_list(arg):
return [Path(p) for p in arg.split(";")]
def main():
parser = argparse.ArgumentParser("VALL-E TTS")
parser.add_argument("text")
parser.add_argument("reference", type=Path)
parser.add_argument("out_path", type=Path)
parser.add_argument("references", type=path_list)
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--ar-ckpt", type=Path, default=None)
parser.add_argument("--nar-ckpt", type=Path, default=None)
@ -17,7 +21,7 @@ def main():
args = parser.parse_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
if __name__ == "__main__":
main()

View File

@ -404,9 +404,19 @@ class Trainer:
@dataclass()
class Inference:
weight_dtype: str = "float32"
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
use_vocos: bool = True
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
return torch.float32
@dataclass()
class BitsAndBytes:
enabled: bool = False

View File

@ -28,8 +28,8 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)
def get_phone_symmap():
#if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
# return json.loads( cfg.hdf5['symmap'].asstr()[()] )
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
return symmap
@ -67,12 +67,9 @@ def _load_quants(path) -> Tensor:
return torch.load(path)[0][:, :].t().to(torch.int16)
@cache
def _get_phones(path, lang_marker="en"):
path = _get_phone_path(path)
with open(path, "r", encoding="utf8") as f:
content = f.read()
split = content.split(" ")
return [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
def _get_phones(path, language="en"):
content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
return ["<s>"] + [ " " if not p else p for p in split ] + ["</s>"]
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
@ -779,6 +776,7 @@ if __name__ == "__main__":
continue
print(text, task, cfg.models.prom_levels)
print( proms.shape, resps.shape )
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
break

View File

@ -182,14 +182,14 @@ def encode_from_file(path, device="cuda"):
# trims a random piece of audio, up to `target`
def trim_random( qnt, target ):
length = qnt.shape[0]
length = max( qnt.shape[0], qnt.shape[1] )
start = int(length * random.random())
end = start + target
if end >= length:
start = length - target
end = length
return qnt[start:end]
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
# repeats the audio to fit the target size
def repeat_extend_audio( qnt, target ):

View File

@ -2,7 +2,9 @@ import torch
import torchaudio
import soundfile
from torch import Tensor
from einops import rearrange
from pathlib import Path
from .emb import g2p, qnt
from .emb.qnt import trim_random
@ -28,16 +30,6 @@ class TTS():
cfg.format()
except Exception as e:
pass
"""
if cfg.trainer.load_state_dict:
for model in cfg.models.get():
path = cfg.ckpt_dir / model.full_name / "fp32.pth"
if model.name.startswith("ar"):
ar_ckpt = path
if model.name.startswith("nar"):
nar_ckpt = path
"""
if ar_ckpt and nar_ckpt:
self.ar_ckpt = ar_ckpt
@ -74,22 +66,39 @@ class TTS():
elif name[:3] == "nar":
self.nar = engine.module.to(self.device)
def encode_text( self, text, lang_marker="en" ):
content = g2p.encode(text)
def encode_text( self, text, language="en" ):
# already a tensor, return it
if isinstance( text, Tensor ):
return text
content = g2p.encode(text, language=language)
#phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
phones = [ " " if not p else p for p in content ]
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
def encode_audio( self, path, trim=True ):
enc = qnt.encode_from_file( path )
res = enc[0].t().to(torch.int16)
def encode_audio( self, paths, trim=True ):
# already a tensor, return it
if isinstance( paths, Tensor ):
return paths
# split string into paths
if isinstance( paths, str ):
paths = [ Path(p) for p in paths.split(";") ]
# merge inputs
res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
if trim:
res = trim_random( res, int( 75 * cfg.dataset.duration_range[1] ) )
res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
return res
@torch.inference_mode()
def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path=None ):
if out_path is None:
out_path = f"./data/{text}.wav"
def inference( self, text, reference, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
prom = self.encode_audio( reference )
prom = self.encode_audio( references )
phns = self.encode_text( text )
prom = to_device(prom, self.device).to(torch.int16)