inferencing cleanup

mrq 2023-08-20 21:36:02 -05:00
parent a47029065b
commit 7b1b82e0e5
5 changed files with 52 additions and 31 deletions

vall_e/__main__.py

@@ -2,11 +2,15 @@ import argparse
 from pathlib import Path
 from .inference import TTS
 
+def path_list(arg):
+	return [Path(p) for p in arg.split(";")]
+
 def main():
 	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("text")
-	parser.add_argument("reference", type=Path)
-	parser.add_argument("out_path", type=Path)
+	parser.add_argument("references", type=path_list)
+	parser.add_argument("--out-path", type=Path, default=None)
 	parser.add_argument("--yaml", type=Path, default=None)
 	parser.add_argument("--ar-ckpt", type=Path, default=None)
 	parser.add_argument("--nar-ckpt", type=Path, default=None)
@@ -17,7 +21,7 @@ def main():
 	args = parser.parse_args()
 
 	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
-	tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
+	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
 
 if __name__ == "__main__":
 	main()
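Note: the positional reference argument is replaced by `references`, parsed through the new `path_list` converter, so multiple clips can be passed as one ";"-delimited string. A minimal sketch of the converter's behavior (file names hypothetical):

	from pathlib import Path

	def path_list(arg):
		# mirrors the converter added above: one ";"-delimited string -> list of Paths
		return [Path(p) for p in arg.split(";")]

	assert path_list("a.wav;b.wav") == [Path("a.wav"), Path("b.wav")]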

vall_e/config.py

@@ -404,9 +404,19 @@ class Trainer:
 @dataclass()
 class Inference:
+	weight_dtype: str = "float32"
+
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
 	use_vocos: bool = True
 
+	@cached_property
+	def dtype(self):
+		if self.weight_dtype == "float16":
+			return torch.float16
+		if self.weight_dtype == "bfloat16":
+			return torch.bfloat16
+		return torch.float32
+
 @dataclass()
 class BitsAndBytes:
 	enabled: bool = False
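Note: the new `weight_dtype` field lets the config pick the inference weight precision by name, with the cached `dtype` property resolving it to a torch dtype. A minimal sketch of the mapping, assuming only the three values handled above are meaningful:

	import torch

	def resolve_dtype(weight_dtype: str) -> torch.dtype:
		# mirrors Inference.dtype: unrecognized strings fall back to float32
		if weight_dtype == "float16":
			return torch.float16
		if weight_dtype == "bfloat16":
			return torch.bfloat16
		return torch.float32

	assert resolve_dtype("bfloat16") is torch.bfloat16
	assert resolve_dtype("not-a-dtype") is torch.float32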

vall_e/data.py

@@ -28,8 +28,8 @@ from tqdm.auto import tqdm
 _logger = logging.getLogger(__name__)
 
 def get_phone_symmap():
-	#if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
-	#	return json.loads( cfg.hdf5['symmap'].asstr()[()] )
+	if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
+		return json.loads( cfg.hdf5['symmap'].asstr()[()] )
 	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
 	return symmap
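Note: the HDF5 fast path is now enabled instead of commented out: when the dataset uses HDF5 and it contains a `symmap` entry, the phoneme map is read from there rather than falling through to the hardcoded dict. A minimal sketch of that read, assuming `symmap` is stored as a scalar JSON string dataset (file name hypothetical):

	import json
	import h5py

	with h5py.File("data.h5", "r") as hdf5:
		if "symmap" in hdf5:
			# asstr() decodes the stored bytes, [()] reads the scalar dataset
			symmap = json.loads(hdf5["symmap"].asstr()[()])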
@@ -67,12 +67,9 @@ def _load_quants(path) -> Tensor:
 	return torch.load(path)[0][:, :].t().to(torch.int16)
 
 @cache
-def _get_phones(path, lang_marker="en"):
-	path = _get_phone_path(path)
-	with open(path, "r", encoding="utf8") as f:
-		content = f.read()
-	split = content.split(" ")
-	return [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
+def _get_phones(path, language="en"):
+	content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
+	return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
 
 def _interleaved_reorder(l, fn):
 	groups = defaultdict(list)
@@ -779,6 +776,7 @@ if __name__ == "__main__":
 			continue
 
 		print(text, task, cfg.models.prom_levels)
+		print( proms.shape, resps.shape )
 		decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
 		decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
 		break

vall_e/emb/qnt.py

@@ -182,14 +182,14 @@ def encode_from_file(path, device="cuda"):
 # trims a random piece of audio, up to `target`
 def trim_random( qnt, target ):
-	length = qnt.shape[0]
+	length = max( qnt.shape[0], qnt.shape[1] )
 	start = int(length * random.random())
 	end = start + target
 	if end >= length:
 		start = length - target
 		end = length
-	return qnt[start:end]
+	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
 
 # repeats the audio to fit the target size
 def repeat_extend_audio( qnt, target ):
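Note: `trim_random` no longer assumes the time axis comes first; it measures length along whichever dimension is longer and slices that one, so both (frames, levels) and (levels, frames) tensors work. A quick self-contained check (shapes hypothetical):

	import random
	import torch

	def trim_random(qnt, target):
		# as in the diff above: treat the longer dimension as time
		length = max(qnt.shape[0], qnt.shape[1])
		start = int(length * random.random())
		end = start + target
		if end >= length:
			start = length - target
			end = length
		return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]

	frames_first = torch.zeros(600, 8, dtype=torch.int16)  # (frames, levels)
	levels_first = torch.zeros(8, 600, dtype=torch.int16)  # (levels, frames)
	assert trim_random(frames_first, 150).shape == (150, 8)
	assert trim_random(levels_first, 150).shape == (8, 150)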

vall_e/inference.py

@@ -2,7 +2,9 @@ import torch
 import torchaudio
 import soundfile
 
+from torch import Tensor
 from einops import rearrange
+from pathlib import Path
 
 from .emb import g2p, qnt
 from .emb.qnt import trim_random
@@ -29,16 +31,6 @@ class TTS():
 			except Exception as e:
 				pass
 
-		"""
-		if cfg.trainer.load_state_dict:
-			for model in cfg.models.get():
-				path = cfg.ckpt_dir / model.full_name / "fp32.pth"
-				if model.name.startswith("ar"):
-					ar_ckpt = path
-				if model.name.startswith("nar"):
-					nar_ckpt = path
-		"""
-
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt
@@ -74,22 +66,39 @@ class TTS():
 		elif name[:3] == "nar":
 			self.nar = engine.module.to(self.device)
 
-	def encode_text( self, text, lang_marker="en" ):
-		content = g2p.encode(text)
+	def encode_text( self, text, language="en" ):
+		# already a tensor, return it
+		if isinstance( text, Tensor ):
+			return text
+
+		content = g2p.encode(text, language=language)
 		#phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
 		phones = [ " " if not p else p for p in content ]
 		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
 
-	def encode_audio( self, path, trim=True ):
-		enc = qnt.encode_from_file( path )
-		res = enc[0].t().to(torch.int16)
+	def encode_audio( self, paths, trim=True ):
+		# already a tensor, return it
+		if isinstance( paths, Tensor ):
+			return paths
+
+		# split string into paths
+		if isinstance( paths, str ):
+			paths = [ Path(p) for p in paths.split(";") ]
+
+		# merge inputs
+		res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
+
 		if trim:
-			res = trim_random( res, int( 75 * cfg.dataset.duration_range[1] ) )
+			res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
 		return res
 
-	def inference( self, text, reference, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
-		prom = self.encode_audio( reference )
+	@torch.inference_mode()
+	def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path=None ):
+		if out_path is None:
+			out_path = f"./data/{text}.wav"
+
+		prom = self.encode_audio( references )
 		phns = self.encode_text( text )
 
 		prom = to_device(prom, self.device).to(torch.int16)
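Note: `encode_audio` now passes tensors through untouched, splits ";"-delimited strings into paths, and concatenates every encoded clip along the time axis into a single prompt before the optional trim. A minimal standalone sketch of the merge, with `encode_clip` as a hypothetical stand-in for `qnt.encode_from_file(path)[0].t().to(torch.int16)`:

	import torch
	from pathlib import Path

	def encode_clip(path):
		# stand-in: pretend every clip encodes to 1 second at 75 frames/s, 8 RVQ levels
		return torch.zeros(75, 8, dtype=torch.int16)

	paths = "a.wav;b.wav"
	if isinstance(paths, str):
		paths = [Path(p) for p in paths.split(";")]

	prom = torch.cat([encode_clip(p) for p in paths])  # concatenated along time (dim 0)
	assert prom.shape == (150, 8)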