diff --git a/modules/api/api.py b/modules/api/api.py index 3df6ff96..3caa83a4 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -33,6 +33,14 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + def __base64_to_image(self, base64_string): + # if has a comma, deal with prefix + if "," in base64_string: + base64_string = base64_string.split(",")[1] + imgdata = base64.b64decode(base64_string) + # convert base64 to PIL image + return Image.open(io.BytesIO(imgdata)) + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -74,26 +82,22 @@ class Api: mask = img2imgreq.mask if mask: - raise HTTPException(status_code=400, detail="Mask not supported yet") + mask = self.__base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, "sampler_index": sampler_index[0], "do_not_save_samples": True, - "do_not_save_grid": True + "do_not_save_grid": True, + "mask": mask } ) p = StableDiffusionProcessingImg2Img(**vars(populate)) imgs = [] for img in init_images: - # if has a comma, deal with prefix - if "," in img: - img = img.split(",")[1] - # convert base64 to PIL image - img = base64.b64decode(img) - img = Image.open(io.BytesIO(img)) + img = self.__base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs