gradio and FastAPI
This commit is contained in:
parent
1df3ff25e6
commit
8d5d863a9d
|
@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app):
|
def __init__(self, app, queue_lock):
|
||||||
self.router = APIRouter()
|
self.router = APIRouter()
|
||||||
app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
self.app = app
|
||||||
|
self.queue_lock = queue_lock
|
||||||
|
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
||||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||||
|
@ -30,7 +32,8 @@ class Api:
|
||||||
)
|
)
|
||||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||||
# Override object param
|
# Override object param
|
||||||
processed = process_images(p)
|
with self.queue_lock:
|
||||||
|
processed = process_images(p)
|
||||||
|
|
||||||
b64images = []
|
b64images = []
|
||||||
for i in processed.images:
|
for i in processed.images:
|
||||||
|
@ -52,5 +55,5 @@ class Api:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(app, host=server_name, port=port)
|
uvicorn.run(self.app, host=server_name, port=port)
|
||||||
|
|
18
webui.py
18
webui.py
|
@ -4,7 +4,7 @@ import time
|
||||||
import importlib
|
import importlib
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
@ -31,7 +31,6 @@ from modules.paths import script_path
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
import modules.hypernetworks.hypernetwork
|
import modules.hypernetworks.hypernetwork
|
||||||
|
|
||||||
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,7 +96,7 @@ def initialize():
|
||||||
|
|
||||||
def create_api(app):
|
def create_api(app):
|
||||||
from modules.api.api import Api
|
from modules.api.api import Api
|
||||||
api = Api(app)
|
api = Api(app, queue_lock)
|
||||||
return api
|
return api
|
||||||
|
|
||||||
def wait_on_server(demo=None):
|
def wait_on_server(demo=None):
|
||||||
|
@ -141,7 +140,7 @@ def webui(launch_api=False):
|
||||||
create_api(app)
|
create_api(app)
|
||||||
|
|
||||||
wait_on_server(demo)
|
wait_on_server(demo)
|
||||||
|
|
||||||
sd_samplers.set_samplers()
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
print('Reloading Custom Scripts')
|
print('Reloading Custom Scripts')
|
||||||
|
@ -153,11 +152,10 @@ def webui(launch_api=False):
|
||||||
print('Restarting Gradio')
|
print('Restarting Gradio')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if not cmd_opts.nowebui:
|
|
||||||
api_only()
|
|
||||||
|
|
||||||
if cmd_opts.api:
|
task = []
|
||||||
webui(True)
|
if __name__ == "__main__":
|
||||||
|
if cmd_opts.nowebui:
|
||||||
|
api_only()
|
||||||
else:
|
else:
|
||||||
webui(False)
|
webui(cmd_opts.api)
|
Loading…
Reference in New Issue
Block a user