From 54dc6f930761a6a471206b2ba74f1cc73c54789f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 1 Sep 2022 19:09:48 +0300
Subject: [PATCH] actual support for share=True in gradio

---
 webui.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/webui.py b/webui.py
index 76edacbc..e5a21b2f 100644
--- a/webui.py
+++ b/webui.py
@@ -37,6 +37,7 @@ from contextlib import nullcontext
 import signal
 import tqdm
 import re
+import threading
 
 import k_diffusion.sampling
 from ldm.util import instantiate_from_config
@@ -75,6 +76,7 @@ cpu = torch.device("cpu")
 gpu = torch.device("cuda")
 device = gpu if torch.cuda.is_available() else cpu
 batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
+queue_lock = threading.Lock()
 
 if not cmd_opts.share:
     # fix gradio phoning home
@@ -643,10 +645,20 @@ def resize_image(resize_mode, im, width, height):
     return res
 
 
+def wrap_gradio_gpu_call(func):
+    def f(*args, **kwargs):
+        with queue_lock:
+            res = func(*args, **kwargs)
+
+        return res
+
+    return f
+
+
 def wrap_gradio_call(func):
-    def f(*p1, **p2):
+    def f(*args, **kwargs):
         t = time.perf_counter()
-        res = list(func(*p1, **p2))
+        res = list(func(*args, **kwargs))
         elapsed = time.perf_counter() - t
 
         # last item is always HTML
@@ -1259,7 +1271,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
                 html_info = gr.HTML()
 
         txt2img_args = dict(
-            fn=wrap_gradio_call(txt2img),
+            fn=wrap_gradio_gpu_call(txt2img),
             inputs=[
                 prompt,
                 negative_prompt,
@@ -1657,7 +1669,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
             )
 
         img2img_args = dict(
-            fn=wrap_gradio_call(img2img),
+            fn=wrap_gradio_gpu_call(img2img),
             inputs=[
                 prompt,
                 init_img,
@@ -1736,7 +1748,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
 
 
 extras_interface = gr.Interface(
-    wrap_gradio_call(run_extras),
+    wrap_gradio_gpu_call(run_extras),
     inputs=[
         gr.Image(label="Source", source="upload", interactive=True, type="pil"),
         gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan),
@@ -1904,6 +1916,5 @@ def inject_gradio_html(javascript):
 
 
 inject_gradio_html(javascript)
 
-demo.queue(concurrency_count=1)
 demo.launch(share=cmd_opts.share)
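
Note (illustrative, not part of the patch): the change above drops demo.queue(concurrency_count=1) and instead serializes GPU-bound handlers through a module-level threading.Lock via wrap_gradio_gpu_call, so overlapping requests (more likely once share=True exposes the UI publicly) run one at a time. A minimal standalone sketch of that pattern follows; the names wrap_gpu_call and fake_generate are hypothetical stand-ins, not code from webui.py.

    import threading
    import time

    # Module-level lock: only one wrapped call may run at a time.
    queue_lock = threading.Lock()


    def wrap_gpu_call(func):
        def f(*args, **kwargs):
            # Other requests block here until the current call releases the lock.
            with queue_lock:
                return func(*args, **kwargs)

        return f


    @wrap_gpu_call
    def fake_generate(job_id):
        # Stand-in for a GPU-heavy handler such as txt2img or img2img.
        print(f"job {job_id} started")
        time.sleep(0.5)
        print(f"job {job_id} finished")


    if __name__ == "__main__":
        # Two "simultaneous" requests: the second waits for the first to finish.
        threads = [threading.Thread(target=fake_generate, args=(i,)) for i in range(2)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()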