actual support for share=True in gradio

This commit is contained in:
AUTOMATIC 2022-09-01 19:09:48 +03:00
parent f1aa1d6711
commit 54dc6f9307

View File

@ -37,6 +37,7 @@ from contextlib import nullcontext
import signal import signal
import tqdm import tqdm
import re import re
import threading
import k_diffusion.sampling import k_diffusion.sampling
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
@ -75,6 +76,7 @@ cpu = torch.device("cpu")
gpu = torch.device("cuda") gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
queue_lock = threading.Lock()
if not cmd_opts.share: if not cmd_opts.share:
# fix gradio phoning home # fix gradio phoning home
@ -643,10 +645,20 @@ def resize_image(resize_mode, im, width, height):
return res return res
def wrap_gradio_gpu_call(func):
def f(*args, **kwargs):
with queue_lock:
res = func(*args, **kwargs)
return res
return f
def wrap_gradio_call(func): def wrap_gradio_call(func):
def f(*p1, **p2): def f(*args, **kwargs):
t = time.perf_counter() t = time.perf_counter()
res = list(func(*p1, **p2)) res = list(func(*args, **kwargs))
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t
# last item is always HTML # last item is always HTML
@ -1259,7 +1271,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
html_info = gr.HTML() html_info = gr.HTML()
txt2img_args = dict( txt2img_args = dict(
fn=wrap_gradio_call(txt2img), fn=wrap_gradio_gpu_call(txt2img),
inputs=[ inputs=[
prompt, prompt,
negative_prompt, negative_prompt,
@ -1657,7 +1669,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface:
) )
img2img_args = dict( img2img_args = dict(
fn=wrap_gradio_call(img2img), fn=wrap_gradio_gpu_call(img2img),
inputs=[ inputs=[
prompt, prompt,
init_img, init_img,
@ -1736,7 +1748,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
extras_interface = gr.Interface( extras_interface = gr.Interface(
wrap_gradio_call(run_extras), wrap_gradio_gpu_call(run_extras),
inputs=[ inputs=[
gr.Image(label="Source", source="upload", interactive=True, type="pil"), gr.Image(label="Source", source="upload", interactive=True, type="pil"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan), gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan),
@ -1904,6 +1916,5 @@ def inject_gradio_html(javascript):
inject_gradio_html(javascript) inject_gradio_html(javascript)
demo.queue(concurrency_count=1)
demo.launch(share=cmd_opts.share) demo.launch(share=cmd_opts.share)