diff --git a/modules/api/api.py b/modules/api/api.py index eb7b1da5..7e802208 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -467,7 +467,7 @@ class Api: shared.state.end() return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e)) - def train_embedding(self, args: dict): + def train_embedding(self, args: TrainEmbeddingAPI): try: shared.state.begin() apply_optimizations = shared.opts.training_xattention_optimizations @@ -476,7 +476,8 @@ class Api: if not apply_optimizations: sd_hijack.undo_optimizations() try: - embedding, filename = train_embedding(**args) # can take a long time to complete + training_args = args.__dict__ + embedding, filename = train_embedding(**training_args) # can take a long time to complete except Exception as e: error = e finally: diff --git a/modules/api/models.py b/modules/api/models.py index cba43d3b..50a943d6 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -267,3 +267,90 @@ class EmbeddingsResponse(BaseModel): class MemoryResponse(BaseModel): ram: dict = Field(title="RAM", description="System memory stats") cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") + +class TrainEmbeddingAPI(BaseModel): + embedding_name: str = Field( + title="Embedding Name", description="Name of the embedding" + ) + learn_rate: float = Field( + title="Learning rate", description="Rate of learning", default=0.005 + ) + batch_size: int = Field( + title="Batch size", + description="How many images to create in a single batch", + default=1, + ) + gradient_step: int = Field( + title="Gradient accumulation steps", + description="Number of steps to accumulate the gradient", + default=1, + ) + data_root: str = Field( + title="Dataset directory", description="Path to the directory with input images" + ) + log_directory: str = Field( + title="Log directory", description="", default="textual_inversion" + ) + training_width: int = Field(title="Training width", description="", default=512) + training_height: int = Field(title="Training height", description="", default=512) + varsize: bool = Field( + title="Do not resize images", description="Do not resize images", default=False + ) + steps: int = Field(title="Max steps", description="", default=100000) + clip_grad_mode: str = Field( + title="Gradient clipping", description="Gradient clip mode", default="disabled" + ) + clip_grad_value: float = Field(title="", description="", default=0.1) + shuffle_tags: bool = Field( + title="Shuffle tags", + description="Shuffle tags by ',' when creating prompts", + default=False, + ) + tag_drop_out: bool = Field( + title="Drop tags", + description="Drop out tags when creating prompts", + default=False, + ) + latent_sampling_method: str = Field( + title="Latent sampling method", description="", default="once" + ) + create_image_every: int = Field( + title="Create image every", + description="Save an image to the log directory every N steps, 0 to disable", + default=500, + ) + save_embedding_every: int = Field( + title="Save embedding", + description="Save a copy of embedding to log directory every N steps, 0 to disable", + default=500, + ) + template_filename: str = Field( + title="Prompt template", + description="Prompt template file", + default="style_filewords.txt", + ) + save_image_with_stored_embedding: bool = Field( + title="Save image with embedding", + description="Save images with embedding in PNG chunks", + default=True + ) + preview_from_txt2img: bool = Field( + title="Preview from txt2img", + description="Read parameters (prompt, etc...) from txt2img tab when making previews", + default=False, + ) + preview_prompt: str = Field(title="Preview prompt", description="", default="") + preview_negative_prompt: str = Field( + title="Preview negative prompt", description="", default="" + ) + preview_steps: int = Field(title="Preview steps", description="", default=20) + preview_sampler_index: str = Field( + title="Preview sampler", description="", default="Euler" + ) + preview_cfg_scale: float = Field( + title="Preview CFG scale", description="", default=7.0 + ) + preview_seed: float = Field(title="Preview seed", description="", default=-1.0) + preview_width: int = Field(title="Preview width", description="", default=512) + preview_height: int = Field(title="Preview height", description="", default=512) + id_task: str = Field(title="ID Task", description="ID Task (unused)", default="")