From 7f8bd2b936ef1820ceee877bffa2c9540f297d77 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 9 Sep 2023 20:05:03 -0500 Subject: [PATCH] added printing elasped inference time --- vall_e/webui.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/vall_e/webui.py b/vall_e/webui.py index a2b7b79..a7cff16 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -7,6 +7,7 @@ import functools import gradio as gr +from time import perf_counter from pathlib import Path from .inference import TTS @@ -32,6 +33,14 @@ def gradio_wrapper(inputs): return wrapped_function return decorated +class timer: + def __enter__(self): + self.start = perf_counter() + return self + + def __exit__(self, type, value, traceback): + print(f'Elapsed time: {(perf_counter() - self.start):.3f}s') + def init_tts(restart=False): global tts @@ -71,20 +80,21 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): tmp = tempfile.NamedTemporaryFile(suffix='.wav') tts = init_tts() - wav, sr = tts.inference( - text=args.text, - references=[args.references.split(";")], - out_path=tmp.name, - max_ar_steps=args.max_ar_steps, - input_prompt_length=args.input_prompt_length, - ar_temp=args.ar_temp, - nar_temp=args.nar_temp, - top_p=args.top_p, - top_k=args.top_k, - repetition_penalty=args.repetition_penalty, - repetition_penalty_decay=args.repetition_penalty_decay, - length_penalty=args.length_penalty - ) + with timer() as t: + wav, sr = tts.inference( + text=args.text, + references=[args.references.split(";")], + out_path=tmp.name, + max_ar_steps=args.max_ar_steps, + input_prompt_length=args.input_prompt_length, + ar_temp=args.ar_temp, + nar_temp=args.nar_temp, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + repetition_penalty_decay=args.repetition_penalty_decay, + length_penalty=args.length_penalty + ) wav = wav.squeeze(0).cpu().numpy() return (sr, wav)