Merge pull request #4294 from evshiron/feat/allow-origins

add --cors-allow-origins cmd opt
This commit is contained in:
AUTOMATIC1111 2022-11-05 16:20:45 +03:00 committed by GitHub
commit 994136b9c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 3 deletions

View File

@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
@ -150,9 +151,9 @@ class State:
self.interrupted = True self.interrupted = True
def nextjob(self): def nextjob(self):
if opts.show_progress_every_n_steps == -1: if opts.show_progress_every_n_steps == -1:
self.do_set_current_image() self.do_set_current_image()
self.job_no += 1 self.job_no += 1
self.sampling_step = 0 self.sampling_step = 0
self.current_image_sampling_step = 0 self.current_image_sampling_step = 0
@ -201,7 +202,7 @@ class State:
return return
if self.current_latent is None: if self.current_latent is None:
return return
if opts.show_progress_grid: if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else: else:

View File

@ -5,6 +5,7 @@ import importlib
import signal import signal
import threading import threading
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
@ -107,6 +108,11 @@ def initialize():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def setup_cors(app):
if cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
def create_api(app): def create_api(app):
from modules.api.api import Api from modules.api.api import Api
api = Api(app, queue_lock) api = Api(app, queue_lock)
@ -128,6 +134,7 @@ def api_only():
initialize() initialize()
app = FastAPI() app = FastAPI()
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app) api = create_api(app)
@ -163,6 +170,8 @@ def webui():
# runnnig its code. We disable this here. Suggested by RyotaK. # runnnig its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
if launch_api: if launch_api: