This commit is contained in:
Dave Lage 2023-02-05 03:53:52 -07:00 committed by GitHub
commit dd45d5f9a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 90 additions and 2 deletions

View File

@ -467,7 +467,7 @@ class Api:
shared.state.end()
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
def train_embedding(self, args: dict):
def train_embedding(self, args: TrainEmbeddingAPI):
try:
shared.state.begin()
apply_optimizations = shared.opts.training_xattention_optimizations
@ -476,7 +476,8 @@ class Api:
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
embedding, filename = train_embedding(**args) # can take a long time to complete
training_args = args.__dict__
embedding, filename = train_embedding(**training_args) # can take a long time to complete
except Exception as e:
error = e
finally:

View File

@ -267,3 +267,90 @@ class EmbeddingsResponse(BaseModel):
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
class TrainEmbeddingAPI(BaseModel):
embedding_name: str = Field(
title="Embedding Name", description="Name of the embedding"
)
learn_rate: float = Field(
title="Learning rate", description="Rate of learning", default=0.005
)
batch_size: int = Field(
title="Batch size",
description="How many images to create in a single batch",
default=1,
)
gradient_step: int = Field(
title="Gradient accumulation steps",
description="Number of steps to accumulate the gradient",
default=1,
)
data_root: str = Field(
title="Dataset directory", description="Path to the directory with input images"
)
log_directory: str = Field(
title="Log directory", description="", default="textual_inversion"
)
training_width: int = Field(title="Training width", description="", default=512)
training_height: int = Field(title="Training height", description="", default=512)
varsize: bool = Field(
title="Do not resize images", description="Do not resize images", default=False
)
steps: int = Field(title="Max steps", description="", default=100000)
clip_grad_mode: str = Field(
title="Gradient clipping", description="Gradient clip mode", default="disabled"
)
clip_grad_value: float = Field(title="", description="", default=0.1)
shuffle_tags: bool = Field(
title="Shuffle tags",
description="Shuffle tags by ',' when creating prompts",
default=False,
)
tag_drop_out: bool = Field(
title="Drop tags",
description="Drop out tags when creating prompts",
default=False,
)
latent_sampling_method: str = Field(
title="Latent sampling method", description="", default="once"
)
create_image_every: int = Field(
title="Create image every",
description="Save an image to the log directory every N steps, 0 to disable",
default=500,
)
save_embedding_every: int = Field(
title="Save embedding",
description="Save a copy of embedding to log directory every N steps, 0 to disable",
default=500,
)
template_filename: str = Field(
title="Prompt template",
description="Prompt template file",
default="style_filewords.txt",
)
save_image_with_stored_embedding: bool = Field(
title="Save image with embedding",
description="Save images with embedding in PNG chunks",
default=True
)
preview_from_txt2img: bool = Field(
title="Preview from txt2img",
description="Read parameters (prompt, etc...) from txt2img tab when making previews",
default=False,
)
preview_prompt: str = Field(title="Preview prompt", description="", default="")
preview_negative_prompt: str = Field(
title="Preview negative prompt", description="", default=""
)
preview_steps: int = Field(title="Preview steps", description="", default=20)
preview_sampler_index: str = Field(
title="Preview sampler", description="", default="Euler"
)
preview_cfg_scale: float = Field(
title="Preview CFG scale", description="", default=7.0
)
preview_seed: float = Field(title="Preview seed", description="", default=-1.0)
preview_width: int = Field(title="Preview width", description="", default=512)
preview_height: int = Field(title="Preview height", description="", default=512)
id_task: str = Field(title="ID Task", description="ID Task (unused)", default="")