diff --git a/modules/api/api.py b/modules/api/api.py new file mode 100644 index 00000000..5b0c934e --- /dev/null +++ b/modules/api/api.py @@ -0,0 +1,68 @@ +from modules.api.processing import StableDiffusionProcessingAPI +from modules.processing import StableDiffusionProcessingTxt2Img, process_images +from modules.sd_samplers import all_samplers +from modules.extras import run_pnginfo +import modules.shared as shared +import uvicorn +from fastapi import Body, APIRouter, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field, Json +import json +import io +import base64 + +sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + +class Api: + def __init__(self, app, queue_lock): + self.router = APIRouter() + self.app = app + self.queue_lock = queue_lock + self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + + def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + sampler_index = sampler_to_index(txt2imgreq.sampler_index) + + if sampler_index is None: + raise HTTPException(status_code=404, detail="Sampler not found") + + populate = txt2imgreq.copy(update={ # Override __init__ params + "sd_model": shared.sd_model, + "sampler_index": sampler_index[0], + "do_not_save_samples": True, + "do_not_save_grid": True + } + ) + p = StableDiffusionProcessingTxt2Img(**vars(populate)) + # Override object param + with self.queue_lock: + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) + + + + def img2imgapi(self): + raise NotImplementedError + + def extrasapi(self): + raise NotImplementedError + + def pnginfoapi(self): + raise NotImplementedError + + def launch(self, server_name, port): + self.app.include_router(self.router) + uvicorn.run(self.app, host=server_name, port=port) diff --git a/modules/api/processing.py b/modules/api/processing.py new file mode 100644 index 00000000..4c541241 --- /dev/null +++ b/modules/api/processing.py @@ -0,0 +1,99 @@ +from inflection import underscore +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field, create_model +from modules.processing import StableDiffusionProcessingTxt2Img +import inspect + + +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + +class ModelDef(BaseModel): + """Assistance Class for Pydantic Dynamic Model Generation""" + + field: str + field_alias: str + field_type: Any + field_value: Any + + +class PydanticModelGenerator: + """ + Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: + source_data is a snapshot of the default values produced by the class + params are the names of the actual keys required by __init__ + """ + + def __init__( + self, + model_name: str = None, + class_instance = None, + additional_fields = None, + ): + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation + + return Optional[field_type] + + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + + self._model_name = model_name + self._class_data = merge_class_params(class_instance) + self._model_def = [ + ModelDef( + field=underscore(k), + field_alias=k, + field_type=field_type_generator(k, v), + field_value=v.default + ) + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED + ] + + for fields in additional_fields: + self._model_def.append(ModelDef( + field=underscore(fields["key"]), + field_alias=fields["key"], + field_type=fields["type"], + field_value=fields["default"])) + + def generate_model(self): + """ + Creates a pydantic BaseModel + from the json and overrides provided at initialization + """ + fields = { + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + } + DynamicModel = create_model(self._model_name, **fields) + DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True + return DynamicModel + +StableDiffusionProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingTxt2Img", + StableDiffusionProcessingTxt2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index 346eea88..ea926fc3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps import random import cv2 from skimage import exposure +from typing import Any, Dict, List, Optional import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram @@ -51,9 +52,15 @@ def get_correct_sampler(p): return sd_samplers.samplers elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): return sd_samplers.samplers_for_img2img + elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI): + return sd_samplers.samplers -class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False): +class StableDiffusionProcessing(): + """ + The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing + + """ + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -86,10 +93,10 @@ class StableDiffusionProcessing: self.denoising_strength: float = 0 self.sampler_noise_scheduler_override = None self.ddim_discretize = opts.ddim_discretize - self.s_churn = opts.s_churn - self.s_tmin = opts.s_tmin - self.s_tmax = float('inf') # not representable as a standard ui option - self.s_noise = opts.s_noise + self.s_churn = s_churn or opts.s_churn + self.s_tmin = s_tmin or opts.s_tmin + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option + self.s_noise = s_noise or opts.s_noise if not seed_enable_extras: self.subseed = -1 @@ -97,6 +104,7 @@ class StableDiffusionProcessing: self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -491,7 +499,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): + def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength @@ -717,4 +725,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): del x devices.torch_gc() - return samples + return samples \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..7ad6d474 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,11 +122,33 @@ def select_checkpoint(): return checkpoint_info +chckpoint_dict_replacements = { + 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', + 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', + 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', +} + + +def transform_checkpoint_dict_key(k): + for text, replacement in chckpoint_dict_replacements.items(): + if k.startswith(text): + k = replacement + k[len(text):] + + return k + + def get_state_dict_from_checkpoint(pl_sd): if "state_dict" in pl_sd: - return pl_sd["state_dict"] + pl_sd = pl_sd["state_dict"] - return pl_sd + sd = {} + for k, v in pl_sd.items(): + new_key = transform_checkpoint_dict_key(k) + + if new_key is not None: + sd[new_key] = v + + return sd def load_model_weights(model, checkpoint_info): @@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info): print(f"Global Step: {pl_sd['global_step']}") sd = get_state_dict_from_checkpoint(pl_sd) - model.load_state_dict(sd, strict=False) + missing, extra = model.load_state_dict(sd, strict=False) if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) diff --git a/modules/shared.py b/modules/shared.py index 6b6d5c41..0540cae9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -76,6 +76,8 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/requirements.txt b/requirements.txt index cf583de9..da1969cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ resize-right torchdiffeq kornia lark +inflection diff --git a/requirements_versions.txt b/requirements_versions.txt index abadcb58..72ccc5a3 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -22,3 +22,4 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 +inflection==0.5.1 diff --git a/webui.py b/webui.py index 86e98ad0..c7260c7a 100644 --- a/webui.py +++ b/webui.py @@ -4,7 +4,7 @@ import time import importlib import signal import threading - +from fastapi import FastAPI from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path @@ -31,7 +31,6 @@ from modules.paths import script_path from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork - queue_lock = threading.Lock() @@ -88,11 +87,7 @@ def initialize(): shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure) shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm) - - -def webui(): - initialize() - + # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -100,8 +95,35 @@ def webui(): signal.signal(signal.SIGINT, sigint_handler) - while 1: +def create_api(app): + from modules.api.api import Api + api = Api(app, queue_lock) + return api + +def wait_on_server(demo=None): + while 1: + time.sleep(0.5) + if demo and getattr(demo, 'do_restart', False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + +def api_only(): + initialize() + + app = FastAPI() + app.add_middleware(GZipMiddleware, minimum_size=1000) + api = create_api(app) + + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + + +def webui(launch_api=False): + initialize() + + while 1: demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) app, local_url, share_url = demo.launch( @@ -113,17 +135,14 @@ def webui(): inbrowser=cmd_opts.autolaunch, prevent_thread_lock=True ) - + app.add_middleware(GZipMiddleware, minimum_size=1000) - while 1: - time.sleep(0.5) - if getattr(demo, 'do_restart', False): - time.sleep(0.5) - demo.close() - time.sleep(0.5) - break + if (launch_api): + create_api(app) + wait_on_server(demo) + sd_samplers.set_samplers() print('Reloading Custom Scripts') @@ -135,5 +154,10 @@ def webui(): print('Restarting Gradio') + +task = [] if __name__ == "__main__": - webui() + if cmd_opts.nowebui: + api_only() + else: + webui(cmd_opts.api) \ No newline at end of file