Implement PR #3189 but for embeddings.

This commit is contained in:
timntorres 2022-10-24 23:22:58 -07:00 committed by AUTOMATIC1111
parent a524d137d0
commit c2dc9bfa89

View File

@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
from modules import shared, devices, sd_hijack, processing, sd_models, images
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@ -247,6 +247,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
@ -296,8 +297,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
forced_filename = f'{embedding_name}-{embedding.step}'
last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
@ -353,8 +354,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
image.save(last_saved_image)
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step