gradio and FastAPI
This commit is contained in:
parent
1df3ff25e6
commit
8d5d863a9d
|
@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel):
|
|||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app):
|
||||
def __init__(self, app, queue_lock):
|
||||
self.router = APIRouter()
|
||||
app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
|
@ -30,7 +32,8 @@ class Api:
|
|||
)
|
||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||
# Override object param
|
||||
processed = process_images(p)
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
|
@ -52,5 +55,5 @@ class Api:
|
|||
raise NotImplementedError
|
||||
|
||||
def launch(self, server_name, port):
|
||||
app.include_router(self.router)
|
||||
uvicorn.run(app, host=server_name, port=port)
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port)
|
||||
|
|
18
webui.py
18
webui.py
|
@ -4,7 +4,7 @@ import time
|
|||
import importlib
|
||||
import signal
|
||||
import threading
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
|
||||
from modules.paths import script_path
|
||||
|
@ -31,7 +31,6 @@ from modules.paths import script_path
|
|||
from modules.shared import cmd_opts
|
||||
import modules.hypernetworks.hypernetwork
|
||||
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
|
@ -97,7 +96,7 @@ def initialize():
|
|||
|
||||
def create_api(app):
|
||||
from modules.api.api import Api
|
||||
api = Api(app)
|
||||
api = Api(app, queue_lock)
|
||||
return api
|
||||
|
||||
def wait_on_server(demo=None):
|
||||
|
@ -141,7 +140,7 @@ def webui(launch_api=False):
|
|||
create_api(app)
|
||||
|
||||
wait_on_server(demo)
|
||||
|
||||
|
||||
sd_samplers.set_samplers()
|
||||
|
||||
print('Reloading Custom Scripts')
|
||||
|
@ -153,11 +152,10 @@ def webui(launch_api=False):
|
|||
print('Restarting Gradio')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not cmd_opts.nowebui:
|
||||
api_only()
|
||||
|
||||
if cmd_opts.api:
|
||||
webui(True)
|
||||
task = []
|
||||
if __name__ == "__main__":
|
||||
if cmd_opts.nowebui:
|
||||
api_only()
|
||||
else:
|
||||
webui(False)
|
||||
webui(cmd_opts.api)
|
Loading…
Reference in New Issue
Block a user