Fix bare base64 not accept

This commit is contained in:
Sena 2022-11-23 17:43:58 +08:00 committed by GitHub
parent 828438b4a1
commit 75b67eebf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,6 +3,7 @@ import io
import time import time
import uvicorn import uvicorn
from threading import Lock from threading import Lock
from io import BytesIO
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
@ -13,7 +14,7 @@ from modules import sd_samplers, deepbooru
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from typing import List from typing import List
@ -133,7 +134,10 @@ class Api:
mask = img2imgreq.mask mask = img2imgreq.mask
if mask: if mask:
mask = decode_base64_to_image(mask) if mask.startswith("data:image/"):
mask = decode_base64_to_image(mask)
else:
mask = Image.open(BytesIO(base64.b64decode(mask)))
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
@ -147,7 +151,10 @@ class Api:
imgs = [] imgs = []
for img in init_images: for img in init_images:
img = decode_base64_to_image(img) if img.startswith("data:image/"):
img = decode_base64_to_image(img)
else:
img = Image.open(BytesIO(base64.b64decode(img)))
imgs = [img] * p.batch_size imgs = [img] * p.batch_size
p.init_images = imgs p.init_images = imgs