added printing elasped inference time
This commit is contained in:
parent
4f61f5c889
commit
7f8bd2b936
|
@ -7,6 +7,7 @@ import functools
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from time import perf_counter
|
||||
from pathlib import Path
|
||||
|
||||
from .inference import TTS
|
||||
|
@ -32,6 +33,14 @@ def gradio_wrapper(inputs):
|
|||
return wrapped_function
|
||||
return decorated
|
||||
|
||||
class timer:
|
||||
def __enter__(self):
|
||||
self.start = perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
print(f'Elapsed time: {(perf_counter() - self.start):.3f}s')
|
||||
|
||||
def init_tts(restart=False):
|
||||
global tts
|
||||
|
||||
|
@ -71,20 +80,21 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
||||
|
||||
tts = init_tts()
|
||||
wav, sr = tts.inference(
|
||||
text=args.text,
|
||||
references=[args.references.split(";")],
|
||||
out_path=tmp.name,
|
||||
max_ar_steps=args.max_ar_steps,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
ar_temp=args.ar_temp,
|
||||
nar_temp=args.nar_temp,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty
|
||||
)
|
||||
with timer() as t:
|
||||
wav, sr = tts.inference(
|
||||
text=args.text,
|
||||
references=[args.references.split(";")],
|
||||
out_path=tmp.name,
|
||||
max_ar_steps=args.max_ar_steps,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
ar_temp=args.ar_temp,
|
||||
nar_temp=args.nar_temp,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty
|
||||
)
|
||||
|
||||
wav = wav.squeeze(0).cpu().numpy()
|
||||
return (sr, wav)
|
||||
|
|
Loading…
Reference in New Issue
Block a user