call sampler by name
This commit is contained in:
parent
e7f4808505
commit
0f0d6ab8e0
|
@ -1,6 +1,7 @@
|
||||||
from modules.api.processing import StableDiffusionProcessingAPI
|
from modules.api.processing import StableDiffusionProcessingAPI
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
|
||||||
from modules.sd_samplers import samplers_k_diffusion
|
from modules.sd_samplers import all_samplers
|
||||||
|
from modules.extras import run_pnginfo
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import Body, APIRouter, HTTPException
|
from fastapi import Body, APIRouter, HTTPException
|
||||||
|
@ -10,7 +11,7 @@ import json
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None)
|
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||||
|
|
||||||
class TextToImageResponse(BaseModel):
|
class TextToImageResponse(BaseModel):
|
||||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||||
|
@ -53,13 +54,13 @@ class Api:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def img2imgendoint(self):
|
def img2imgapi(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def extrasendoint(self):
|
def extrasapi(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def pnginfoendoint(self):
|
def pnginfoapi(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from inflection import underscore
|
from inflection import underscore
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, Field, create_model
|
||||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,5 +95,5 @@ class PydanticModelGenerator:
|
||||||
StableDiffusionProcessingAPI = PydanticModelGenerator(
|
StableDiffusionProcessingAPI = PydanticModelGenerator(
|
||||||
"StableDiffusionProcessingTxt2Img",
|
"StableDiffusionProcessingTxt2Img",
|
||||||
StableDiffusionProcessingTxt2Img,
|
StableDiffusionProcessingTxt2Img,
|
||||||
[{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
|
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||||
).generate_model()
|
).generate_model()
|
Loading…
Reference in New Issue
Block a user