inferencing cleanup
This commit is contained in:
parent
a47029065b
commit
7b1b82e0e5
|
@ -2,11 +2,15 @@ import argparse
|
|||
from pathlib import Path
|
||||
from .inference import TTS
|
||||
|
||||
def path_list(arg):
|
||||
return [Path(p) for p in arg.split(";")]
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("VALL-E TTS")
|
||||
parser.add_argument("text")
|
||||
parser.add_argument("reference", type=Path)
|
||||
parser.add_argument("out_path", type=Path)
|
||||
parser.add_argument("references", type=path_list)
|
||||
parser.add_argument("--out-path", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--ar-ckpt", type=Path, default=None)
|
||||
parser.add_argument("--nar-ckpt", type=Path, default=None)
|
||||
|
@ -17,7 +21,7 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
|
||||
tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
|
||||
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -404,9 +404,19 @@ class Trainer:
|
|||
|
||||
@dataclass()
|
||||
class Inference:
|
||||
weight_dtype: str = "float32"
|
||||
|
||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||
use_vocos: bool = True
|
||||
|
||||
@cached_property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
|
||||
@dataclass()
|
||||
class BitsAndBytes:
|
||||
enabled: bool = False
|
||||
|
|
|
@ -28,8 +28,8 @@ from tqdm.auto import tqdm
|
|||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def get_phone_symmap():
|
||||
#if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
|
||||
# return json.loads( cfg.hdf5['symmap'].asstr()[()] )
|
||||
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
|
||||
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
|
||||
|
||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
||||
return symmap
|
||||
|
@ -67,12 +67,9 @@ def _load_quants(path) -> Tensor:
|
|||
return torch.load(path)[0][:, :].t().to(torch.int16)
|
||||
|
||||
@cache
|
||||
def _get_phones(path, lang_marker="en"):
|
||||
path = _get_phone_path(path)
|
||||
with open(path, "r", encoding="utf8") as f:
|
||||
content = f.read()
|
||||
split = content.split(" ")
|
||||
return [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
def _get_phones(path, language="en"):
|
||||
content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
|
||||
return ["<s>"] + [ " " if not p else p for p in split ] + ["</s>"]
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
groups = defaultdict(list)
|
||||
|
@ -779,6 +776,7 @@ if __name__ == "__main__":
|
|||
continue
|
||||
|
||||
print(text, task, cfg.models.prom_levels)
|
||||
print( proms.shape, resps.shape )
|
||||
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
|
||||
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
||||
break
|
|
@ -182,14 +182,14 @@ def encode_from_file(path, device="cuda"):
|
|||
|
||||
# trims a random piece of audio, up to `target`
|
||||
def trim_random( qnt, target ):
|
||||
length = qnt.shape[0]
|
||||
length = max( qnt.shape[0], qnt.shape[1] )
|
||||
start = int(length * random.random())
|
||||
end = start + target
|
||||
if end >= length:
|
||||
start = length - target
|
||||
end = length
|
||||
|
||||
return qnt[start:end]
|
||||
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
|
||||
|
||||
# repeats the audio to fit the target size
|
||||
def repeat_extend_audio( qnt, target ):
|
||||
|
|
|
@ -2,7 +2,9 @@ import torch
|
|||
import torchaudio
|
||||
import soundfile
|
||||
|
||||
from torch import Tensor
|
||||
from einops import rearrange
|
||||
from pathlib import Path
|
||||
|
||||
from .emb import g2p, qnt
|
||||
from .emb.qnt import trim_random
|
||||
|
@ -29,16 +31,6 @@ class TTS():
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
"""
|
||||
if cfg.trainer.load_state_dict:
|
||||
for model in cfg.models.get():
|
||||
path = cfg.ckpt_dir / model.full_name / "fp32.pth"
|
||||
if model.name.startswith("ar"):
|
||||
ar_ckpt = path
|
||||
if model.name.startswith("nar"):
|
||||
nar_ckpt = path
|
||||
"""
|
||||
|
||||
if ar_ckpt and nar_ckpt:
|
||||
self.ar_ckpt = ar_ckpt
|
||||
self.nar_ckpt = nar_ckpt
|
||||
|
@ -74,22 +66,39 @@ class TTS():
|
|||
elif name[:3] == "nar":
|
||||
self.nar = engine.module.to(self.device)
|
||||
|
||||
def encode_text( self, text, lang_marker="en" ):
|
||||
content = g2p.encode(text)
|
||||
def encode_text( self, text, language="en" ):
|
||||
# already a tensor, return it
|
||||
if isinstance( text, Tensor ):
|
||||
return text
|
||||
|
||||
content = g2p.encode(text, language=language)
|
||||
#phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
||||
phones = [ " " if not p else p for p in content ]
|
||||
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
|
||||
|
||||
def encode_audio( self, path, trim=True ):
|
||||
enc = qnt.encode_from_file( path )
|
||||
res = enc[0].t().to(torch.int16)
|
||||
def encode_audio( self, paths, trim=True ):
|
||||
# already a tensor, return it
|
||||
if isinstance( paths, Tensor ):
|
||||
return paths
|
||||
|
||||
# split string into paths
|
||||
if isinstance( paths, str ):
|
||||
paths = [ Path(p) for p in paths.split(";") ]
|
||||
|
||||
# merge inputs
|
||||
res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
|
||||
|
||||
if trim:
|
||||
res = trim_random( res, int( 75 * cfg.dataset.duration_range[1] ) )
|
||||
res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
|
||||
|
||||
return res
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path=None ):
|
||||
if out_path is None:
|
||||
out_path = f"./data/{text}.wav"
|
||||
|
||||
def inference( self, text, reference, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
|
||||
prom = self.encode_audio( reference )
|
||||
prom = self.encode_audio( references )
|
||||
phns = self.encode_text( text )
|
||||
|
||||
prom = to_device(prom, self.device).to(torch.int16)
|
||||
|
|
Loading…
Reference in New Issue
Block a user