added printing elasped inference time

This commit is contained in:
mrq 2023-09-09 20:05:03 -05:00
parent 4f61f5c889
commit 7f8bd2b936

View File

@ -7,6 +7,7 @@ import functools
import gradio as gr import gradio as gr
from time import perf_counter
from pathlib import Path from pathlib import Path
from .inference import TTS from .inference import TTS
@ -32,6 +33,14 @@ def gradio_wrapper(inputs):
return wrapped_function return wrapped_function
return decorated return decorated
class timer:
def __enter__(self):
self.start = perf_counter()
return self
def __exit__(self, type, value, traceback):
print(f'Elapsed time: {(perf_counter() - self.start):.3f}s')
def init_tts(restart=False): def init_tts(restart=False):
global tts global tts
@ -71,20 +80,21 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
tmp = tempfile.NamedTemporaryFile(suffix='.wav') tmp = tempfile.NamedTemporaryFile(suffix='.wav')
tts = init_tts() tts = init_tts()
wav, sr = tts.inference( with timer() as t:
text=args.text, wav, sr = tts.inference(
references=[args.references.split(";")], text=args.text,
out_path=tmp.name, references=[args.references.split(";")],
max_ar_steps=args.max_ar_steps, out_path=tmp.name,
input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps,
ar_temp=args.ar_temp, input_prompt_length=args.input_prompt_length,
nar_temp=args.nar_temp, ar_temp=args.ar_temp,
top_p=args.top_p, nar_temp=args.nar_temp,
top_k=args.top_k, top_p=args.top_p,
repetition_penalty=args.repetition_penalty, top_k=args.top_k,
repetition_penalty_decay=args.repetition_penalty_decay, repetition_penalty=args.repetition_penalty,
length_penalty=args.length_penalty repetition_penalty_decay=args.repetition_penalty_decay,
) length_penalty=args.length_penalty
)
wav = wav.squeeze(0).cpu().numpy() wav = wav.squeeze(0).cpu().numpy()
return (sr, wav) return (sr, wav)