vall-e/vall_e/__main__.py

28 lines
1.1 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
import argparse
from pathlib import Path
from .inference import TTS
2023-08-21 02:36:02 +00:00
def path_list(arg):
    """Parse a ``;``-separated string into a list of paths.

    Intended as an argparse ``type=`` converter. Each entry has surrounding
    whitespace stripped and empty entries are dropped. For example,
    ``"a.wav; b.wav;"`` yields ``[Path("a.wav"), Path("b.wav")]``. Without
    this, it would yield ``Path(" b.wav")`` and ``Path("")``, and pathlib
    turns the latter into ``Path(".")``, i.e. the current directory.

    Args:
        arg: Raw command-line string.

    Returns:
        List of ``Path`` objects, in the order given.

    Raises:
        argparse.ArgumentTypeError: If no non-empty entry remains. argparse
            turns this into a normal usage error.
    """
    paths = [Path(p) for p in (part.strip() for part in arg.split(";")) if p]
    if not paths:
        raise argparse.ArgumentTypeError(f"expected at least one path, got {arg!r}")
    return paths
def main(argv=None):
    """Command-line entry point: parse arguments, build a TTS, run inference.

    Args:
        argv: Optional list of argument strings to parse. The default ``None``
            makes argparse read the real command line, so ``main()`` behaves
            exactly as before. Passing a list lets callers and tests drive the
            CLI programmatically.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the name
    # shown in usage lines), not a description.
    parser = argparse.ArgumentParser("VALL-E TTS")
    parser.add_argument("text", help="text to synthesize")
    parser.add_argument(
        "references",
        type=path_list,
        help="one or more reference paths, separated by ';'",
    )
    parser.add_argument(
        "--out-path",
        type=Path,
        default=None,
        help="output path, passed to TTS.inference as out_path",
    )
    parser.add_argument(
        "--yaml",
        type=Path,
        default=None,
        help="YAML config path, passed to TTS as config",
    )
    parser.add_argument(
        "--ar-ckpt", type=Path, default=None, help="AR model checkpoint path"
    )
    parser.add_argument(
        "--nar-ckpt", type=Path, default=None, help="NAR model checkpoint path"
    )
    # NOTE(review): 6 * 75 presumably means ~6 seconds at a 75-steps-per-second
    # code rate; confirm against the audio codec configuration.
    parser.add_argument(
        "--max-ar-steps",
        type=int,
        default=6 * 75,
        help="maximum number of AR steps (default: %(default)s)",
    )
    parser.add_argument(
        "--ar-temp",
        type=float,
        default=1.0,
        help="AR temperature (default: %(default)s)",
    )
    parser.add_argument(
        "--nar-temp",
        type=float,
        default=1.0,
        help="NAR temperature (default: %(default)s)",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to run on (default: %(default)s)"
    )
    args = parser.parse_args(argv)

    tts = TTS(
        config=args.yaml,
        ar_ckpt=args.ar_ckpt,
        nar_ckpt=args.nar_ckpt,
        device=args.device,
    )
    tts.inference(
        text=args.text,
        references=args.references,
        out_path=args.out_path,
        max_ar_steps=args.max_ar_steps,
        ar_temp=args.ar_temp,
        nar_temp=args.nar_temp,
    )


if __name__ == "__main__":
    main()