Merge 5ca3ecc9ed
into ea9bd9fc74
This commit is contained in:
commit
dd45d5f9a3
|
@ -467,7 +467,7 @@ class Api:
|
|||
shared.state.end()
|
||||
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
||||
|
||||
def train_embedding(self, args: dict):
|
||||
def train_embedding(self, args: TrainEmbeddingAPI):
|
||||
try:
|
||||
shared.state.begin()
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
|
@ -476,7 +476,8 @@ class Api:
|
|||
if not apply_optimizations:
|
||||
sd_hijack.undo_optimizations()
|
||||
try:
|
||||
embedding, filename = train_embedding(**args) # can take a long time to complete
|
||||
training_args = args.__dict__
|
||||
embedding, filename = train_embedding(**training_args) # can take a long time to complete
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
|
|
|
@ -267,3 +267,90 @@ class EmbeddingsResponse(BaseModel):
|
|||
class MemoryResponse(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||
|
||||
class TrainEmbeddingAPI(BaseModel):
|
||||
embedding_name: str = Field(
|
||||
title="Embedding Name", description="Name of the embedding"
|
||||
)
|
||||
learn_rate: float = Field(
|
||||
title="Learning rate", description="Rate of learning", default=0.005
|
||||
)
|
||||
batch_size: int = Field(
|
||||
title="Batch size",
|
||||
description="How many images to create in a single batch",
|
||||
default=1,
|
||||
)
|
||||
gradient_step: int = Field(
|
||||
title="Gradient accumulation steps",
|
||||
description="Number of steps to accumulate the gradient",
|
||||
default=1,
|
||||
)
|
||||
data_root: str = Field(
|
||||
title="Dataset directory", description="Path to the directory with input images"
|
||||
)
|
||||
log_directory: str = Field(
|
||||
title="Log directory", description="", default="textual_inversion"
|
||||
)
|
||||
training_width: int = Field(title="Training width", description="", default=512)
|
||||
training_height: int = Field(title="Training height", description="", default=512)
|
||||
varsize: bool = Field(
|
||||
title="Do not resize images", description="Do not resize images", default=False
|
||||
)
|
||||
steps: int = Field(title="Max steps", description="", default=100000)
|
||||
clip_grad_mode: str = Field(
|
||||
title="Gradient clipping", description="Gradient clip mode", default="disabled"
|
||||
)
|
||||
clip_grad_value: float = Field(title="", description="", default=0.1)
|
||||
shuffle_tags: bool = Field(
|
||||
title="Shuffle tags",
|
||||
description="Shuffle tags by ',' when creating prompts",
|
||||
default=False,
|
||||
)
|
||||
tag_drop_out: bool = Field(
|
||||
title="Drop tags",
|
||||
description="Drop out tags when creating prompts",
|
||||
default=False,
|
||||
)
|
||||
latent_sampling_method: str = Field(
|
||||
title="Latent sampling method", description="", default="once"
|
||||
)
|
||||
create_image_every: int = Field(
|
||||
title="Create image every",
|
||||
description="Save an image to the log directory every N steps, 0 to disable",
|
||||
default=500,
|
||||
)
|
||||
save_embedding_every: int = Field(
|
||||
title="Save embedding",
|
||||
description="Save a copy of embedding to log directory every N steps, 0 to disable",
|
||||
default=500,
|
||||
)
|
||||
template_filename: str = Field(
|
||||
title="Prompt template",
|
||||
description="Prompt template file",
|
||||
default="style_filewords.txt",
|
||||
)
|
||||
save_image_with_stored_embedding: bool = Field(
|
||||
title="Save image with embedding",
|
||||
description="Save images with embedding in PNG chunks",
|
||||
default=True
|
||||
)
|
||||
preview_from_txt2img: bool = Field(
|
||||
title="Preview from txt2img",
|
||||
description="Read parameters (prompt, etc...) from txt2img tab when making previews",
|
||||
default=False,
|
||||
)
|
||||
preview_prompt: str = Field(title="Preview prompt", description="", default="")
|
||||
preview_negative_prompt: str = Field(
|
||||
title="Preview negative prompt", description="", default=""
|
||||
)
|
||||
preview_steps: int = Field(title="Preview steps", description="", default=20)
|
||||
preview_sampler_index: str = Field(
|
||||
title="Preview sampler", description="", default="Euler"
|
||||
)
|
||||
preview_cfg_scale: float = Field(
|
||||
title="Preview CFG scale", description="", default=7.0
|
||||
)
|
||||
preview_seed: float = Field(title="Preview seed", description="", default=-1.0)
|
||||
preview_width: int = Field(title="Preview width", description="", default=512)
|
||||
preview_height: int = Field(title="Preview height", description="", default=512)
|
||||
id_task: str = Field(title="ID Task", description="ID Task (unused)", default="")
|
||||
|
|
Loading…
Reference in New Issue
Block a user