diff --git a/modules/api/api.py b/modules/api/api.py index 25c65e57..2657f98b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -466,7 +466,7 @@ class Api: shared.state.end() return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e)) - def train_embedding(self, args: dict): + def train_embedding(self, args: TrainEmbeddingAPI): try: shared.state.begin() apply_optimizations = shared.opts.training_xattention_optimizations @@ -475,7 +475,8 @@ class Api: if not apply_optimizations: sd_hijack.undo_optimizations() try: - embedding, filename = train_embedding(**args) # can take a long time to complete + training_args = args.__dict__ + embedding, filename = train_embedding(**training_args) # can take a long time to complete except Exception as e: error = e finally: diff --git a/modules/api/models.py b/modules/api/models.py index 805bd8f7..ebf0d3c8 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -267,3 +267,85 @@ class EmbeddingsResponse(BaseModel): class MemoryResponse(BaseModel): ram: dict = Field(title="RAM", description="System memory stats") cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") + +class TrainEmbeddingAPI(BaseModel): + id_task: str = Field(title="ID", description="ID Task") + embedding_name: str = Field( + title="Embedding Name", description="Name of the embedding" + ) + learn_rate: float = Field( + title="Learning rate", description="Rate of learning", default=0.005 + ) + batch_size: int = Field( + title="Batch size", + description="How many images to create in a single batch", + default=1, + ) + gradient_step: int = Field( + title="Gradient accumulation steps", + description="Number of steps to accumulate the gradient", + default=1, + ) + data_root: str = Field( + title="Dataset directory", description="Path to the directory with input images" + ) + log_directory: str = Field( + title="Log directory", description="", default="textual_inversion" + ) + training_width: int = Field(title="Training width", description="", default=512) + training_height: int = Field(title="Training height", description="", default=512) + varsize: str = Field(title="", description="") + steps: int = Field(title="Max steps", description="") + clip_grad_mode: str = Field( + title="Gradient clipping", description="Gradient clip mode" + ) + clip_grad_value: float = Field(title="", description="", default=0.1) + shuffle_tags: bool = Field( + title="Shuffle tags", + description="Shuffle tags by ',' when creating prompts", + default=False, + ) + tag_drop_out: bool = Field( + title="Drop tags", + description="Drop out tags when creating prompts", + default=False, + ) + latent_sampling_method: str = Field( + title="Latent sampling method", description="", default="once" + ) + create_image_every: int = Field( + title="Create image every", + description="Save an image to the log directory every N steps, 0 to disable", + default=500, + ) + save_embedding_every: int = Field( + title="Save embedding", + description="Save a copy of embedding to log directory every N steps, 0 to disable", + default=500, + ) + template_filename: str = Field( + title="Prompt template", + description="Prompt template file", + default="style_filewords.txt", + ) + save_image_with_stored_embedding: str = Field( + title="Save image with embedding", + description="Save images with embedding in PNG chunks", + ) + preview_from_txt2img: bool = Field( + title="Preview from txt2img", + description="Read parameters (prompt, etc...) from txt2img tab when making previews", + default=False, + ) + preview_prompt: str = Field(title="Preview prompt", description="") + preview_negative_prompt: str = Field( + title="Preview negative prompt", description="" + ) + preview_steps: int = Field(title="Preview steps", description="", default=20) + preview_sampler_index: str = Field( + title="Preview sampler", description="", default="Euler" + ) + preview_cfg_scale: float = Field(title="Preview CFG scale", description="") + preview_seed: float = Field(title="Preview seed", description="") + preview_width: int = Field(title="Preview width", description="") + preview_height: int = Field(title="Preview height", description="")