From 29e74d6e71826da9a3fe3c5790fed1329fc4d1e8 Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 16:26:16 +0200 Subject: [PATCH 01/38] Add support for Tensorboard for training embeddings --- modules/shared.py | 4 +++ .../textual_inversion/textual_inversion.py | 31 ++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index faede821..2c6341f7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -254,6 +254,10 @@ options_templates.update(options_section(('training', "Training"), { "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), + "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."), + "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."), + "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), + })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3be69562..c57d3ace 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,9 +7,11 @@ import tqdm import html import datetime import csv +import numpy as np +import torchvision.transforms from PIL import Image, PngImagePlugin - +from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -199,6 +201,19 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def tensorboard_add_scaler(tensorboard_writer, tag, value, step): + if shared.opts.training_enable_tensorboard: + tensorboard_writer.add_scalar(tag=tag, + scalar_value=value, global_step=step) + +def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): + if shared.opts.training_enable_tensorboard: + # Convert a pil image to a torch tensor + img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands())) + img_tensor = img_tensor.permute((2, 0, 1)) + + tensorboard_writer.add_image(tag, img_tensor, global_step=step) def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' @@ -252,6 +267,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) + if shared.opts.training_enable_tensorboard: + os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) + tensorboard_writer = SummaryWriter( + 
log_dir=os.path.join(log_directory, "tensorboard"), + flush_secs=shared.opts.training_tensorboard_flush_every) + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, entries in pbar: embedding.step = i + ititial_step @@ -270,6 +291,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc del x losses[embedding.step % losses.shape[0]] = loss.item() + optimizer.zero_grad() loss.backward() @@ -285,6 +307,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding.save(last_saved_file) embedding_yet_to_be_embedded = True + if shared.opts.training_enable_tensorboard: + tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step) + tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step) + tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step) + tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step) + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { "loss": f"{losses.mean():.7f}", "learn_rate": scheduler.learn_rate @@ -349,6 +377,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False image.save(last_saved_image) + tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) last_saved_image += f", prompt: {preview_text}" From a6d593a6b51dc6a8443f2aa5c24caa391a04cd56 Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 19:43:21 +0200 Subject: [PATCH 02/38] Fixed a typo in a variable --- modules/textual_inversion/textual_inversion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c57d3ace..ec8176bf 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -260,11 +260,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc last_saved_image = "" embedding_yet_to_be_embedded = False - ititial_step = embedding.step or 0 - if ititial_step > steps: + initial_step = embedding.step or 0 + if initial_step > steps: return embedding, filename - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) if shared.opts.training_enable_tensorboard: @@ -273,9 +273,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc log_dir=os.path.join(log_directory, "tensorboard"), flush_secs=shared.opts.training_tensorboard_flush_every) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step) for i, entries in pbar: - embedding.step = i + ititial_step + embedding.step = i + initial_step scheduler.apply(optimizer, embedding.step) if scheduler.finished: From 8f5912984794c4c69e429c4636e984854d911b6a Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 22:37:16 +0200 Subject: [PATCH 03/38] Some changes to the tensorboard code and hypernetwork support --- modules/hypernetworks/hypernetwork.py | 18 +++++++- .../textual_inversion/textual_inversion.py | 45 +++++++++++-------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py 
b/modules/hypernetworks/hypernetwork.py index 74300122..5e919775 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -4,6 +4,7 @@ import html import os import sys import traceback +import tensorboard import tqdm import csv @@ -18,7 +19,6 @@ import modules.textual_inversion.dataset from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler - class HypernetworkModule(torch.nn.Module): multiplier = 1.0 @@ -291,6 +291,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + if shared.opts.training_enable_tensorboard: + tensorboard_writer = textual_inversion.tensorboard_setup(log_directory) + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -315,6 +318,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer.zero_grad() loss.backward() optimizer.step() + mean_loss = losses.mean() if torch.isnan(mean_loss): raise RuntimeError("Loss diverged.") @@ -323,6 +327,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) + + if shared.opts.training_enable_tensorboard: + epoch_num = hypernetwork.step // len(ds) + epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 + + textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, + global_step=hypernetwork.step, step=epoch_step, + learn_rate=scheduler.learn_rate, epoch_num=epoch_num) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{mean_loss:.7f}", @@ -360,6 +372,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log processed = processing.process_images(p) image = processed.images[0] if len(processed.images)>0 else None + if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", + image, hypernetwork.step) + if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ec8176bf..b1dc2596 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -201,19 +201,30 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def tensorboard_setup(log_directory): + os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) + return SummaryWriter( + log_dir=os.path.join(log_directory, "tensorboard"), + flush_secs=shared.opts.training_tensorboard_flush_every) + +def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num): + tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step) + tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step) + tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step) + 
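An illustrative sketch (editor's example, not part of the patch) of what the logging helpers above reduce to, using only SummaryWriter calls that ship with PyTorch; the log directory, step counts, and loss values below are invented:

import numpy as np
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="textual_inversion/tensorboard", flush_secs=120)

# Scalars: one global curve plus a per-epoch curve, mirroring tensorboard_add.
writer.add_scalar("Loss/train", 0.123, global_step=500)
writer.add_scalar("Loss/train/epoch-2", 0.123, global_step=34)

# Images: add_image expects a CHW tensor by default, while PIL/numpy give
# HWC, hence the permute in tensorboard_add_image.
pil_image = Image.new("RGB", (64, 64))
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))  # HWC, uint8
img_tensor = img_tensor.permute((2, 0, 1))                    # -> CHW
writer.add_image("Validation at epoch 2", img_tensor, global_step=500)
writer.close()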
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) + def tensorboard_add_scaler(tensorboard_writer, tag, value, step): - if shared.opts.training_enable_tensorboard: - tensorboard_writer.add_scalar(tag=tag, - scalar_value=value, global_step=step) + tensorboard_writer.add_scalar(tag=tag, + scalar_value=value, global_step=step) def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): - if shared.opts.training_enable_tensorboard: - # Convert a pil image to a torch tensor - img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) - img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands())) - img_tensor = img_tensor.permute((2, 0, 1)) + # Convert a pil image to a torch tensor + img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], + len(pil_image.getbands())) + img_tensor = img_tensor.permute((2, 0, 1)) - tensorboard_writer.add_image(tag, img_tensor, global_step=step) + tensorboard_writer.add_image(tag, img_tensor, global_step=step) def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' @@ -268,10 +279,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) if shared.opts.training_enable_tensorboard: - os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) - tensorboard_writer = SummaryWriter( - log_dir=os.path.join(log_directory, "tensorboard"), - flush_secs=shared.opts.training_tensorboard_flush_every) + tensorboard_writer = tensorboard_setup(log_directory) pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step) for i, entries in pbar: @@ -308,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = True if shared.opts.training_enable_tensorboard: - tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step) - tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step) - tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step) - tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step) + tensorboard_add(tensorboard_writer, loss=losses.mean(), global_step=embedding.step, + step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num) write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { "loss": f"{losses.mean():.7f}", @@ -377,7 +383,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False image.save(last_saved_image) - tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) + + if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", + image, embedding.step) last_saved_image += f", prompt: {preview_text}" From 
7543cf5e3b5eaced00582da257801227d1ff2a6e Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 22:43:08 +0200 Subject: [PATCH 04/38] Fixed some typos in the code --- modules/hypernetworks/hypernetwork.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5e919775..0cd94f49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -284,19 +284,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log last_saved_file = "" last_saved_image = "" - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: + initial_step = hypernetwork.step or 0 + if initial_step > steps: return hypernetwork, filename - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) if shared.opts.training_enable_tensorboard: tensorboard_writer = textual_inversion.tensorboard_setup(log_directory) - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step) for i, entries in pbar: - hypernetwork.step = i + ititial_step + hypernetwork.step = i + initial_step scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: From 18f86e41f6f289042c075bff1498e620ab997b8c Mon Sep 17 00:00:00 2001 From: Melan Date: Mon, 24 Oct 2022 17:21:18 +0200 Subject: [PATCH 05/38] Removed two unused imports --- modules/hypernetworks/hypernetwork.py | 1 - modules/textual_inversion/textual_inversion.py | 1 - 2 files changed, 2 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 0cd94f49..2263e95e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -4,7 +4,6 @@ import html import os import sys import traceback -import tensorboard import tqdm import csv diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b1dc2596..589314fe 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -9,7 +9,6 @@ import datetime import csv import numpy as np -import torchvision.transforms from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models From 6be644fa04ce1542f3a01804310cbbc0a4a91620 Mon Sep 17 00:00:00 2001 From: dan Date: Wed, 11 Jan 2023 05:31:58 +0800 Subject: [PATCH 06/38] Enable batch_size>1 for mixed-sized training --- modules/textual_inversion/dataset.py | 36 ++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index fa48708e..b47414f3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,8 +3,10 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import transforms +from collections import defaultdict +from random import shuffle, choices import random import tqdm @@ -45,12 +47,12 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert 
os.listdir(data_root), "Dataset directory is empty" - assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out + groups = defaultdict(list) print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -103,13 +105,14 @@ class PersonalizedBase(Dataset): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): with devices.autocast(): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - + groups[image.size].append(len(self.dataset)) self.dataset.append(entry) del torchdata del latent_dist del latent_sample self.length = len(self.dataset) + self.groups = list(groups.values()) assert self.length > 0, "No images have been found in the dataset." self.batch_size = min(batch_size, self.length) self.gradient_step = min(gradient_step, self.length // self.batch_size) @@ -137,9 +140,34 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry +class GroupedBatchSampler(Sampler): + def __init__(self, data_source: PersonalizedBase, batch_size: int): + n = len(data_source) + self.groups = data_source.groups + self.len = n_batch = n // batch_size + expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] + self.base = [int(e) // batch_size for e in expected] + self.n_rand_batches = nrb = n_batch - sum(self.base) + self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] + self.batch_size = batch_size + def __len__(self): + return self.len + def __iter__(self): + b = self.batch_size + for g in self.groups: + shuffle(g) + batches = [] + for g in self.groups: + batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) + for _ in range(self.n_rand_batches): + rand_group = choices(self.groups, self.probs)[0] + batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): - super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory) + super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) if latent_sampling_method == "random": self.collate_fn = collate_wrapper_random else: From f9706acf431f77e0ce9e4270e5be7299922ee963 Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Tue, 10 Jan 2023 18:40:34 -0700 Subject: [PATCH 07/38] Support loading textual inversion embeddings from safetensors files --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5420903f..3866c154 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -9,6 +9,7 @@ import tqdm import html import datetime import csv +import safetensors.torch from PIL import Image, PngImagePlugin @@ -150,6 +151,8 @@ class EmbeddingDatabase: name = data.get('name', name) elif ext in ['.BIN', '.PT']: data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + data = safetensors.torch.load_file(path, device="cpu") else: return From 5830095b73515fc49b3fd567048470005191ec34 Mon Sep 17 00:00:00 
2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 10 Jan 2023 21:43:24 -0500 Subject: [PATCH 08/38] Add old prompt parser compat option --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index 264264a6..b61bbd3f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -400,6 +400,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), + "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. Useful to reproduce old seeds."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { From 7e45fba55b24166501033a221e6268545fa47fbe Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 10 Jan 2023 21:47:03 -0500 Subject: [PATCH 09/38] Fix prompt parser default step transformer w/ test --- modules/prompt_parser.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f70872c4..b69f1425 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -3,6 +3,11 @@ from collections import namedtuple from typing import List import lark +try: + from modules.shared import opts +except: + pass + # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -49,6 +54,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): [[5, 'a c'], [10, 'a {b|d{ c']] >>> g("((a][:b:c [d:3]") [[3, '((a][:b:c '], [10, '((a][:b:c d']] + >>> g("[a|(b:1.1)]") + [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']] """ def collect_steps(steps, tree): @@ -84,7 +91,13 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): yield args[0].value def __default__(self, data, children, meta): for child in children: - yield from child + try: + if opts.use_old_prompt_parser_default_step_transformer: + yield from child + else: + yield child + except: + yield child return AtStep().transform(tree) def get_schedule(prompt): From 37a230112198adcb3f24d59b399cff342a6d479e Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 10 Jan 2023 20:30:09 -0800 Subject: [PATCH 10/38] Expose the compiled class module of scripts to extensions --- modules/scripts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 35164093..4ffc369b 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -152,7 +152,7 @@ def basedir(): scripts_data = [] ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) -ScriptClassData = 
namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) def list_scripts(scriptdirname, extension): @@ -206,7 +206,7 @@ def load_scripts(): for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) except Exception: print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) @@ -241,7 +241,7 @@ class ScriptRunner: self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir in scripts_data: + for script_class, path, basedir, script_module in scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img From 954091697fce7a1b7997d5f3d73551f793f6bebc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 09:10:07 +0300 Subject: [PATCH 11/38] add an option to copy config from one of models in checkpoint merger --- modules/extras.py | 30 +++++++++++++++++++++++++++++- modules/ui.py | 9 ++++++--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 7407bfe3..a03d558e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -3,6 +3,7 @@ import math import os import sys import traceback +import shutil import numpy as np from PIL import Image @@ -248,7 +249,32 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + return sd_models.find_checkpoint_config(x) if x else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): shared.state.begin() shared.state.job = 'model-merge' @@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam sd_models.list_models() + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + print("Checkpoint saved.") shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.end() diff --git a/modules/ui.py b/modules/ui.py index 3c458ce8..82f5dd7c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1129,7 +1129,7 @@ def create_ui(): with gr.Column(variant='panel'): gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
") - with gr.Row(): + with FormRow(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") @@ -1143,11 +1143,13 @@ def create_ui(): interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - with gr.Row(): + with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) @@ -1703,6 +1705,7 @@ def create_ui(): save_as_half, custom_name, checkpoint_format, + config_source, ], outputs=[ submit_result, From 4fdacd31e48c6a7a35c1c25c559932585e8addde Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:24:56 +0300 Subject: [PATCH 12/38] possible fix for fallback for fast model creation from config --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index b5bc12f0..a0a8a909 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -337,6 +337,9 @@ def load_model(checkpoint_info=None): with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) except Exception as e: + pass + + if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) From 1a23dc32ac5e16fac10115cafd0b841abd06e59f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:34:36 +0300 Subject: [PATCH 13/38] possible fix for fallback for fast model creation from config, attempt 2 --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index a0a8a909..084ba7fa 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -333,6 +333,7 @@ def load_model(checkpoint_info=None): timer = Timer() + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) From b202714b65aa2145ff965ed4f197ac1516093f34 Mon Sep 17 00:00:00 2001 From: Alexey Shirokov <40300551+demiurge-ash@users.noreply.github.com> Date: Wed, 11 Jan 2023 11:41:50 +0300 Subject: [PATCH 14/38] Fix keyboard navigation in modal image viewer --- javascript/imageviewer.js | 1 + 1 file changed, 1 insertion(+) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index b7bc2fe1..1f29ad7b 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -151,6 +151,7 @@ function showGalleryImage() { 
e.addEventListener('mousedown', function (evt) { if(!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) + evt.preventDefault() showModal(evt) }, true); } From ab388d6f8bf51338de1950b3907c324b0ff6a872 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 11 Jan 2023 08:59:47 -0500 Subject: [PATCH 15/38] Remove compat option check for prompt parser --- modules/prompt_parser.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b69f1425..870218db 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -3,11 +3,6 @@ from collections import namedtuple from typing import List import lark -try: - from modules.shared import opts -except: - pass - # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -91,13 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): yield args[0].value def __default__(self, data, children, meta): for child in children: - try: - if opts.use_old_prompt_parser_default_step_transformer: - yield from child - else: - yield child - except: - yield child + yield child return AtStep().transform(tree) def get_schedule(prompt): From 0b38b72d31ead82c7d0998a29e50da90073831f7 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 11 Jan 2023 09:01:37 -0500 Subject: [PATCH 16/38] Remove compat option for prompt parser --- modules/shared.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index b61bbd3f..264264a6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -400,7 +400,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), - "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. 
Useful to reproduce old seeds."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { From 39ea251945d70efcf9b59d44eb0e71269d754aa4 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 11 Jan 2023 10:23:51 -0500 Subject: [PATCH 17/38] add textinfo to progress response --- modules/api/api.py | 4 ++-- modules/api/models.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 6c564ad8..5767ba90 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -286,7 +286,7 @@ class Api: # copy from check_progress_call of ui.py if shared.state.job_count == 0: - return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict()) + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) # avoid dividing zero progress = 0.01 @@ -308,7 +308,7 @@ class Api: if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) def interrogateapi(self, interrogatereq: InterrogateRequest): image_b64 = interrogatereq.image diff --git a/modules/api/models.py b/modules/api/models.py index 034b4aa0..c78095ca 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -168,6 +168,7 @@ class ProgressResponse(BaseModel): eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") class InterrogateRequest(BaseModel): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") From 3f43d8a966ba8462ba019a5ad573f94508cd45f8 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 11 Jan 2023 10:28:55 -0500 Subject: [PATCH 18/38] set descriptions --- modules/hypernetworks/hypernetwork.py | 4 +++- modules/textual_inversion/preprocess.py | 7 ++++++- modules/textual_inversion/textual_inversion.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 300d3975..194679e8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -619,7 +619,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, epoch_num = hypernetwork.step // steps_per_epoch epoch_step = hypernetwork.step % steps_per_epoch - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + pbar.set_description(description) + shared.state.textinfo = description if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. 
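A hedged client-side sketch of how the textinfo field added to ProgressResponse above can be read; it assumes the webui API is reachable at its default local address, and the printed training text is only an example:

import requests  # any HTTP client works; requests is assumed to be installed

resp = requests.get("http://127.0.0.1:7860/sdapi/v1/progress", timeout=5)
data = resp.json()
print(data["progress"], data["eta_relative"])
# e.g. "Training hypernetwork [Epoch 3: 41/200]loss: 0.0123456" while training
print(data.get("textinfo"))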
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index feb876c6..3c1042ad 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre params.process_caption_deepbooru = process_caption_deepbooru params.preprocess_txt_action = preprocess_txt_action - for index, imagefile in enumerate(tqdm.tqdm(files)): + pbar = tqdm.tqdm(files) + for index, imagefile in enumerate(pbar): params.subindex = 0 filename = os.path.join(src, imagefile) try: @@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre except Exception: continue + description = f"Preprocessing [Image {index}/{len(files)}]" + pbar.set_description(description) + shared.state.textinfo = description + params.src = filename existing_caption = None diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3866c154..b915b091 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -476,7 +476,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + pbar.set_description(description) + shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. 
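A small worked example of the epoch arithmetic used in the progress descriptions above, with a made-up 200-image epoch:

steps_per_epoch = 200
for step in (0, 199, 200, 450):
    epoch_num = step // steps_per_epoch
    epoch_step = step % steps_per_epoch
    print(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]")
# -> [Epoch 0: 1/200] [Epoch 0: 200/200] [Epoch 1: 1/200] [Epoch 2: 51/200]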
embedding_name_every = f'{embedding_name}-{steps_done}' From 4bd490727e156ff53107d53416d6b89be86f2a62 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 18:54:04 +0300 Subject: [PATCH 19/38] fix for an error caused by skipping initialization, for realsies this time: TypeError: expected str, bytes or os.PathLike object, not NoneType --- modules/sd_disable_initialization.py | 71 +++++++++++++--------------- modules/sd_models.py | 1 + 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 088ac24b..c72d8efc 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,6 +20,19 @@ class DisableInitialization: ``` """ + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + def __enter__(self): def do_nothing(*args, **kwargs): pass @@ -37,11 +50,14 @@ class DisableInitialization: def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): # this file is always 404, prevent making request - if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json': - raise transformers.utils.hub.EntryNotFoundError + if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': + return None try: - return original(url, *args, local_files_only=True, **kwargs) + res = original(url, *args, local_files_only=True, **kwargs) + if res is None: + res = original(url, *args, local_files_only=False, **kwargs) + return res except Exception as e: return original(url, *args, local_files_only=False, **kwargs) @@ -54,42 +70,19 @@ class DisableInitialization: def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) - self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ - self.init_no_grad_normal = torch.nn.init._no_grad_normal_ - self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_ - self.create_model_and_transforms = open_clip.create_model_and_transforms - self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained - self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None) - self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None) - self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None) - self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None) - - torch.nn.init.kaiming_uniform_ = do_nothing - torch.nn.init._no_grad_normal_ = do_nothing - torch.nn.init._no_grad_uniform_ = do_nothing - open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained - ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained - if self.transformers_modeling_utils_load_pretrained_model is not None: - transformers.modeling_utils.PreTrainedModel._load_pretrained_model = 
transformers_modeling_utils_load_pretrained_model - if self.transformers_tokenization_utils_base_cached_file is not None: - transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file - if self.transformers_configuration_utils_cached_file is not None: - transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file - if self.transformers_utils_hub_get_from_cache is not None: - transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache + self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) + self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) + self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) + self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) + self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) + self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) + self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) + self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) + self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform - torch.nn.init._no_grad_normal_ = self.init_no_grad_normal - torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_ - open_clip.create_model_and_transforms = self.create_model_and_transforms - ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained - if self.transformers_modeling_utils_load_pretrained_model is not None: - transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model - if self.transformers_tokenization_utils_base_cached_file is not None: - transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file - if self.transformers_configuration_utils_cached_file is not None: - transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file - if self.transformers_utils_hub_get_from_cache is not None: - transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() diff --git a/modules/sd_models.py b/modules/sd_models.py index 084ba7fa..c466f273 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -334,6 +334,7 @@ def load_model(checkpoint_info=None): timer = Timer() sd_model = None + try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) From 0b8911d883118daa54f7735c5b753b5575d9f943 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 20:33:24 +0300 Subject: [PATCH 20/38] img2img UI rework: obsolete --gradio-img2img-tool --gradio-inpaint-tool and always show all tools each in own tab --- modules/img2img.py | 56 ++++++++++++------------ modules/shared.py | 4 +- 
modules/ui.py | 103 +++++++++++++++++++++++---------------------- style.css | 4 +- 4 files changed, 83 insertions(+), 84 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index ca58b5d8..f62783c6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): - is_inpaint = mode == 1 - is_batch = mode == 2 +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): + is_batch = mode == 5 - if is_inpaint: - # Drawn mask - if mask_mode == 0: - is_mask_sketch = isinstance(init_img_with_mask, dict) - is_mask_paint = not is_mask_sketch - if is_mask_sketch: - # Sketch: mask iff. not transparent - image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') - else: - # Color-sketch: mask iff. 
painted over - image = init_img_with_mask - orig = init_img_with_mask_orig or init_img_with_mask - pred = np.any(np.array(image) != np.array(orig), axis=-1) - mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") - mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) - blur = ImageFilter.GaussianBlur(mask_blur) - image = Image.composite(image.filter(blur), orig, mask.filter(blur)) - - image = image.convert("RGB") - # Uploaded mask - else: - image = init_img_inpaint - mask = init_mask_inpaint - # No mask + if mode == 0: # img2img + image = init_img.convert("RGB") + mask = None + elif mode == 1: # img2img sketch + image = sketch.convert("RGB") + mask = None + elif mode == 2: # inpaint + image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] + alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') + mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') + image = image.convert("RGB") + elif mode == 3: # inpaint sketch + image = inpaint_color_sketch + orig = inpaint_color_sketch_orig or inpaint_color_sketch + pred = np.any(np.array(image) != np.array(orig), axis=-1) + mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") + mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) + blur = ImageFilter.GaussianBlur(mask_blur) + image = Image.composite(image.filter(blur), orig, mask.filter(blur)) + image = image.convert("RGB") + elif mode == 4: # inpaint upload mask + image = init_img_inpaint + mask = init_mask_inpaint else: - image = init_img + image = None mask = None # Use the EXIF orientation of photos taken by smartphones. diff --git a/modules/shared.py b/modules/shared.py index 264264a6..1c964237 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") -parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") +parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') +parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 82f5dd7c..e86a624b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -795,53 +795,67 @@ def create_ui(): with FormRow().style(equal_height=False): with gr.Column(variant='panel', 
elem_id="img2img_settings"): + with gr.Tabs(elem_id="mode_img2img"): + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480) - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) + inpaint_color_sketch_orig = gr.State(None) - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") + with gr.TabItem('Inpaint upload', 
id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
                    gr.HTML(f"<p>Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -900,20 +914,6 @@ def create_ui(): ] ) - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", @@ -924,11 +924,12 @@ def create_ui(): img2img_prompt_style, img2img_prompt_style2, init_img, + sketch, init_img_with_mask, - init_img_with_mask_orig, + inpaint_color_sketch, + inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, - mask_mode, steps, sampler_index, mask_blur, diff --git a/style.css b/style.css index ec5e4182..ffd6307f 100644 --- a/style.css +++ b/style.css @@ -557,7 +557,9 @@ canvas[key="mask"] { } #img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img, -img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img +#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img, +#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img, +#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img { height: 480px !important; max-height: 480px !important; From d52a80f7f7da160c73afd067c8f1bf491391f994 Mon Sep 17 00:00:00 2001 From: Shondoit Date: Thu, 12 Jan 2023 09:22:29 +0100 Subject: [PATCH 21/38] Allow creation of zero vectors 
for TI --- modules/textual_inversion/textual_inversion.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b915b091..853246a6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): with devices.autocast(): cond_model([""]) # will send cond model to GPU if lowvram/medvram is active - embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token) + #cond_model expects at least some text, so we provide '*' as backup. + embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) - for i in range(num_vectors_per_token): - vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + #Only copy if we provided an init_text, otherwise keep vectors as zeros + if init_text: + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) From d48dcbd2b29eab492d53d78f482356d78e5beb19 Mon Sep 17 00:00:00 2001 From: Shondoit Date: Thu, 12 Jan 2023 09:53:35 +0100 Subject: [PATCH 22/38] Add zero vector feature to hints.js Also add the note that some tokens might be skipped. Not everyone is aware of this. --- javascript/hints.js | 1 + 1 file changed, 1 insertion(+) diff --git a/javascript/hints.js b/javascript/hints.js index 856e1389..244bfde2 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -92,6 +92,7 @@ titles = { "Weighted sum": "Result = A * (1 - M) + B * M", "Add difference": "Result = A + (B - C) * M", + "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. 
If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", From 5623a3e7b1beed61f3ae6829a05b7b861d70e203 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 12 Jan 2023 19:47:33 +0300 Subject: [PATCH 23/38] fix send to inpaint sending you to wrong place --- javascript/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/ui.js b/javascript/ui.js index ee226927..a41dd26f 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -54,7 +54,7 @@ function switch_to_img2img(){ function switch_to_inpaint(){ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click(); - gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click(); + gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click(); return args_to_array(arguments); } From 6ffefdcc9f47b66cbc543690d97cbf8327f4ba58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 12 Jan 2023 19:47:44 +0300 Subject: [PATCH 24/38] fix js errors when restarting UI --- script.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/script.js b/script.js index 0e117d06..21960d91 100644 --- a/script.js +++ b/script.js @@ -1,5 +1,6 @@ function gradioApp() { - const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot + const elems = document.getElementsByTagName('gradio-app') + const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot return !!gradioShadowRoot ? 
gradioShadowRoot : document; } From 88416ab5ff787eec3b9962b43b5e544bb75fbad6 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:46:59 -0800 Subject: [PATCH 25/38] Fix extension parameters not being saved to last used parameters --- modules/processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f04a0e1e..ae04cab7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -531,16 +531,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: def infotext(iteration=0, position_in_batch=0): return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") - file.write(processed.infotext(p, 0)) - if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() if p.scripts is not None: p.scripts.process(p) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + infotexts = [] output_images = [] From 6c88eaed4f5efca54a882eb1f8f30f01f350332a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:50:09 -0800 Subject: [PATCH 26/38] Add script callback for fixing infotext parameters --- modules/generation_parameters_copypaste.py | 3 ++- modules/script_callbacks.py | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 620aa606..593d99ef 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared, ui_tempdir +from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): prompt = file.read() params = parse_generation_parameters(prompt) + script_callbacks.infotext_pasted_callback(prompt, params) res = [] for output, key in paste_fields: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 608c5300..a9e19236 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import sys import traceback from collections import namedtuple import inspect -from typing import Optional +from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_infotext_pasted=[], callbacks_script_unloaded=[], ) @@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): + for c in callback_map['callbacks_infotext_pasted']: + try: + c.callback(infotext, params) + except Exception: + report_exception(c, 'infotext_pasted') + + def script_unloaded_callback(): for c in reversed(callback_map['callbacks_script_unloaded']): try: @@ -290,6 +299,15 @@ def on_image_grid(callback): 
add_callback(callback_map['callbacks_image_grid'], callback) +def on_infotext_pasted(callback): + """register a function to be called before applying an infotext. + The callback is called with two arguments: + - infotext: str - raw infotext. + - result: Dict[str, any] - parsed infotext parameters. + """ + add_callback(callback_map['callbacks_infotext_pasted'], callback) + + def on_script_unloaded(callback): """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that the script did should be reverted here""" From 0b262802b86a55c4f71faf377f2cb1aee2960b63 Mon Sep 17 00:00:00 2001 From: Josh R Date: Thu, 12 Jan 2023 17:31:05 -0800 Subject: [PATCH 27/38] add gradient settings to training settings log files --- modules/textual_inversion/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py index 8b1981d5..31e50b64 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/logging.py @@ -2,7 +2,7 @@ import datetime import json import os -saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet From a176d89487d92f5a5b152401e5c424b34ff43b96 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 13 Jan 2023 14:32:15 +0300 Subject: [PATCH 28/38] print bucket sizes for training without resizing images #6620 fix an error when generating a picture with embedding in it --- modules/textual_inversion/dataset.py | 16 ++++++++++++++++ modules/textual_inversion/image_embedding.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b47414f3..d31963d4 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,6 +118,12 @@ class PersonalizedBase(Dataset): self.gradient_step = min(gradient_step, self.length // self.batch_size) self.latent_sampling_method = latent_sampling_method + if len(groups) > 1: + print("Buckets:") + for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): + print(f" {w}x{h}: {len(ids)}") + print() + def create_text(self, filename_text): text = random.choice(self.lines) tags = filename_text.split(',') @@ -140,8 +146,11 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry + class GroupedBatchSampler(Sampler): def __init__(self, data_source: PersonalizedBase, batch_size: int): + super().__init__(data_source) + n = len(data_source) self.groups = data_source.groups self.len = n_batch = n // batch_size @@ 
-150,21 +159,28 @@ class GroupedBatchSampler(Sampler): self.n_rand_batches = nrb = n_batch - sum(self.base) self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] self.batch_size = batch_size + def __len__(self): return self.len + def __iter__(self): b = self.batch_size + for g in self.groups: shuffle(g) + batches = [] for g in self.groups: batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) for _ in range(self.n_rand_batches): rand_group = choices(self.groups, self.probs)[0] batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index ea653806..5593f88c 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -76,10 +76,10 @@ def insert_image_data_embed(image, data): next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) next_size = next_size + ((h*d)-(next_size % (h*d))) - data_np_low.resize(next_size) + data_np_low = np.resize(data_np_low, next_size) data_np_low = data_np_low.reshape((h, -1, d)) - data_np_high.resize(next_size) + data_np_high = np.resize(data_np_high, next_size) data_np_high = data_np_high.reshape((h, -1, d)) edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 853246a6..e23906ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch - description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: From 82725f0ac439f7e3b67858d55900e95330bbd326 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 13 Jan 2023 15:04:37 +0300 Subject: [PATCH 29/38] fix a bug caused by merge --- modules/textual_inversion/textual_inversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 85210b0e..6939efcc 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,6 +11,7 @@ import datetime import csv import safetensors.torch +import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter From d753a9df952ea640acbce724e8153356c8b68424 Mon Sep 17 00:00:00 2001 From: Zaprudin Aleksey Date: Fri, 13 Jan 2023 22:25:33 +0500 Subject: [PATCH 30/38] fix progress bar behavior for "Prompts from file or textbox" script --- scripts/prompts_from_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py 
index 2751f98a..1fe10a7c 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -146,7 +146,7 @@ class Script(scripts.Script): else: args = {"prompt": line} - n_iter = args.get("n_iter", 1) + n_iter = args.get("n_iter", p.n_iter) if n_iter != 1: job_count += n_iter else: From a95f1353089bdeaccd7c266b40cdd79efedfe632 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 09:56:59 +0300 Subject: [PATCH 31/38] change hash to sha256 --- .gitignore | 1 + modules/api/api.py | 2 +- modules/api/models.py | 3 +- modules/hashes.py | 72 +++++++++++ modules/hypernetworks/hypernetwork.py | 4 +- modules/sd_models.py | 118 +++++++++++------- modules/shared.py | 2 +- .../textual_inversion/textual_inversion.py | 6 +- webui.py | 2 + 9 files changed, 159 insertions(+), 51 deletions(-) create mode 100644 modules/hashes.py diff --git a/.gitignore b/.gitignore index 21fa26a7..0b1d17ca 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ notification.mp3 /extensions /test/stdout.txt /test/stderr.txt +/cache.json diff --git a/modules/api/api.py b/modules/api/api.py index 5767ba90..9814bbc2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -371,7 +371,7 @@ class Api: return upscalers def get_sd_models(self): - return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index c78095ca..1eb1fcf1 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -224,7 +224,8 @@ class UpscalerItem(BaseModel): class SDModelItem(BaseModel): title: str = Field(title="Title") model_name: str = Field(title="Model Name") - hash: str = Field(title="Hash") + hash: Optional[str] = Field(title="Short hash") + sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") config: str = Field(title="Config file") diff --git a/modules/hashes.py b/modules/hashes.py new file mode 100644 index 00000000..ebfbd90c --- /dev/null +++ b/modules/hashes.py @@ -0,0 +1,72 @@ +import hashlib +import json +import os.path + +import filelock + + +cache_filename = "cache.json" +cache_data = None + + +def dump_cache(): + with filelock.FileLock(cache_filename+".lock"): + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + +def cache(subsection): + global cache_data + + if cache_data is None: + with filelock.FileLock(cache_filename+".lock"): + if not os.path.isfile(cache_filename): + cache_data = {} + else: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + + s = cache_data.get(subsection, {}) + cache_data[subsection] = s + + return s + + +def calculate_sha256(filename): + hash_sha256 = hashlib.sha256() + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def sha256(filename, title): + hashes = cache("hashes") + ondisk_mtime = os.path.getmtime(filename) + + if title in hashes: + cached_sha256 = hashes[title].get("sha256", None) + cached_mtime = hashes[title].get("mtime", 0) + + if ondisk_mtime 
<= cached_mtime and cached_sha256 is not None: + return cached_sha256 + + print(f"Calculating sha256 for {filename}: ", end='') + sha256_value = calculate_sha256(filename) + print(f"{sha256_value}") + + hashes[title] = { + "mtime": ondisk_mtime, + "sha256": sha256_value, + } + + dump_cache() + + return sha256_value + + + + + diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 83cbb4f0..9b5f2e79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -509,7 +509,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.opts.save_training_settings_to_txt: saved_params = dict( - model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} ) logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) @@ -737,7 +737,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None try: - hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint = checkpoint.shorthash hypernetwork.sd_checkpoint_name = checkpoint.model_name hypernetwork.name = hypernetwork_name hypernetwork.save(filename) diff --git a/modules/sd_models.py b/modules/sd_models.py index c466f273..7babb9ae 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,17 +14,56 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} +checkpoint_alisases = {} checkpoints_loaded = collections.OrderedDict() + +class CheckpointInfo: + def __init__(self, filename): + self.filename = filename + abspath = os.path.abspath(filename) + + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') + else: + name = os.path.basename(filename) + + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + self.title = name + self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + self.hash = model_hash(filename) + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + self.shorthash = None + self.sha256 = None + + def register(self): + checkpoints_list[self.title] = self + for id in self.ids: + checkpoint_alisases[id] = self + + def calculate_shorthash(self): + self.sha256 = hashes.sha256(self.filename, self.title) + self.shorthash = self.sha256[0:10] + + if self.shorthash not in 
self.ids: + self.ids += [self.shorthash, self.sha256] + self.register() + + return self.shorthash + + try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -43,10 +82,14 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - convert = lambda name: int(name) if name.isdigit() else name.lower() - alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] - return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def checkpoint_tiles(): + def convert(name): + return int(name) if name.isdigit() else name.lower() + + def alphanumeric_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) def find_checkpoint_config(info): @@ -62,48 +105,38 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() + checkpoint_alisases.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) - def modeltitle(path, shorthash): - abspath = os.path.abspath(path) - - if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): - name = abspath.replace(shared.cmd_opts.ckpt_dir, '') - elif abspath.startswith(model_path): - name = abspath.replace(model_path, '') - else: - name = os.path.basename(path) - - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - - return f'{name} [{shorthash}]', shortname - cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): - h = model_hash(cmd_ckpt) - title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) - shared.opts.data['sd_model_checkpoint'] = title + checkpoint_info = CheckpointInfo(cmd_ckpt) + checkpoint_info.register() + + shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) + for filename in model_list: - h = model_hash(filename) - title, short_model_name = modeltitle(filename, h) - - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) + checkpoint_info = CheckpointInfo(filename) + checkpoint_info.register() -def get_closet_checkpoint_match(searchString): - applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) - if len(applicable) > 0: - return applicable[0] +def get_closet_checkpoint_match(search_string): + checkpoint_info = checkpoint_alisases.get(search_string, None) + if checkpoint_info is not None: + return + + found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + return None def model_hash(filename): + """old hash that only looks at a small part of the file and is prone to collisions""" + try: with open(filename, "rb") as file: import hashlib @@ -119,7 +152,7 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - checkpoint_info = checkpoints_list.get(model_checkpoint, None) + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) if 
checkpoint_info is not None: return checkpoint_info @@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info, vae_file="auto"): - checkpoint_file = checkpoint_info.filename - sd_model_hash = checkpoint_info.hash +def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): + sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") - sd = read_state_dict(checkpoint_file) + sd = read_state_dict(checkpoint_info.filename) model.load_state_dict(sd, strict=False) del sd @@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash - model.sd_model_checkpoint = checkpoint_file + model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) sd_vae.load_vae(model, vae_file) diff --git a/modules/shared.py b/modules/shared.py index b90ded52..d74c069d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,7 +428,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), + "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6939efcc..63935878 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, 
model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method @@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) + footer_mid = '[{}]'.format(checkpoint.shorthash) footer_right = '{}v {}s'.format(vectorSize, steps_done) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) @@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None try: - embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint = checkpoint.shorthash embedding.sd_checkpoint_name = checkpoint.model_name if remove_cached_checksum: embedding.cached_checksum = None diff --git a/webui.py b/webui.py index 47d372c7..1fff80da 100644 --- a/webui.py +++ b/webui.py @@ -78,6 +78,8 @@ def initialize(): print("Stable diffusion model failed to load, exiting", file=sys.stderr) exit(1) + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) From f9ac3352cb66ce2bc0aa4325130fc7267fb35e4f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 10:25:21 +0300 Subject: [PATCH 32/38] change hypernets to use sha256 hashes --- modules/hypernetworks/hypernetwork.py | 40 +++++++++++++++------------ modules/processing.py | 2 +- modules/sd_models.py | 2 +- modules/shared.py | 1 + 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 9b5f2e79..3aebefa8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers +from modules import devices, processing, sd_models, shared, sd_samplers, hashes from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -225,7 +225,7 @@ class Hypernetwork: torch.save(state_dict, filename) if shared.opts.save_optimizer_state and self.optimizer_state_dict: - optimizer_saved_dict['hash'] = sd_models.model_hash(filename) + optimizer_saved_dict['hash'] = self.shorthash() optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict torch.save(optimizer_saved_dict, filename + '.optim') @@ -237,32 +237,33 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) - 
print(self.layer_structure) - optional_info = state_dict.get('optional_info', None) - if optional_info is not None: - print(f"INFO:\n {optional_info}\n") - self.optional_info = optional_info + self.optional_info = state_dict.get('optional_info', None) self.activation_func = state_dict.get('activation_func', None) - print(f"Activation function is {self.activation_func}") self.weight_init = state_dict.get('weight_initialization', 'Normal') - print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) - print(f"Layer norm is set to {self.add_layer_norm}") self.dropout_structure = state_dict.get('dropout_structure', None) self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False) - print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) - print(f"Activate last layer is set to {self.activate_output}") self.last_layer_dropout = state_dict.get('last_layer_dropout', False) # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0. if self.dropout_structure is None: - print("Using previous dropout structure") self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) - print(f"Dropout structure is set to {self.dropout_structure}") - optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} + if shared.opts.print_hypernet_extra: + if self.optional_info is not None: + print(f" INFO:\n {self.optional_info}\n") - if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): + print(f" Layer structure: {self.layer_structure}") + print(f" Activation function: {self.activation_func}") + print(f" Weight initialization: {self.weight_init}") + print(f" Layer norm: {self.add_layer_norm}") + print(f" Dropout usage: {self.use_dropout}" ) + print(f" Activate last layer: {self.activate_output}") + print(f" Dropout structure: {self.dropout_structure}") + + optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {} + + if self.shorthash() == optimizer_saved_dict.get('hash', None): self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) else: self.optimizer_state_dict = None @@ -289,6 +290,11 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) self.eval() + def shorthash(self): + sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}') + + return sha256[0:10] + def list_hypernetworks(path): res = {} @@ -296,7 +302,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. 
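        # Hypernetworks are keyed by bare name here; the sha256-based short hash
        # is resolved lazily via Hypernetwork.shorthash() when it is needed.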
if name != "None": - res[name + f"({sd_models.model_hash(filename)})"] = filename + res[name] = filename return res diff --git a/modules/processing.py b/modules/processing.py index ae04cab7..849f6b19 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)), + "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), diff --git a/modules/sd_models.py b/modules/sd_models.py index 7babb9ae..8f00191c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -125,7 +125,7 @@ def list_models(): def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_alisases.get(search_string, None) if checkpoint_info is not None: - return + return checkpoint_info found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) if found: diff --git a/modules/shared.py b/modules/shared.py index d74c069d..a6c61db3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. 
Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), })) options_templates.update(options_section(('training', "Training"), { From febd2b722e80959b89a0e5966a159b4eb430c5a5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 13:37:55 +0300 Subject: [PATCH 33/38] update key to use with checkpoints' sha256 in cache --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f00191c..1fe6d11b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -54,7 +54,7 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, self.title) + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) self.shorthash = self.sha256[0:10] if self.shorthash not in self.ids: From 08c6f009a5ee92dd3218a942c08e8337c26352be Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 15:55:40 +0300 Subject: [PATCH 34/38] load hashes from cache for checkpoints that have them add checkpoint hash to footer --- javascript/ui.js | 25 ++++++++++++++++--------- modules/hashes.py | 26 +++++++++++++++++++------- modules/sd_models.py | 9 ++++++--- modules/shared.py | 1 + modules/ui.py | 2 ++ script.js | 4 ++++ 6 files changed, 48 insertions(+), 19 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index a41dd26f..1e04a8f4 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -143,14 +143,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { opts = {} -function apply_settings(jsdata){ - console.log(jsdata) - - opts = JSON.parse(jsdata) - - return jsdata -} - onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -160,7 +152,7 @@ onUiUpdate(function(){ textarea = json_elem.querySelector('textarea') jsdata = textarea.value opts = JSON.parse(jsdata) - + executeCallbacks(optionsChangedCallbacks); Object.defineProperty(textarea, 'value', { set: function(newValue) { @@ -171,6 +163,8 @@ onUiUpdate(function(){ if (oldValue != newValue) { opts = JSON.parse(textarea.value) } + + executeCallbacks(optionsChangedCallbacks); }, get: function() { var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value'); @@ -201,6 +195,19 @@ onUiUpdate(function(){ } }) + +onOptionsChanged(function(){ + elem = gradioApp().getElementById('sd_checkpoint_hash') + sd_checkpoint_hash = opts.sd_checkpoint_hash || "" + shorthash = sd_checkpoint_hash.substr(0,10) + + if(elem && elem.textContent != shorthash){ + elem.textContent = shorthash + elem.title = sd_checkpoint_hash + elem.href = "https://google.com/search?q=" + sd_checkpoint_hash + } +}) + let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 let token_timeout; diff --git a/modules/hashes.py b/modules/hashes.py index ebfbd90c..14231771 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -42,23 +42,35 @@ def calculate_sha256(filename): return hash_sha256.hexdigest() -def sha256(filename, title): +def sha256_from_cache(filename, title): hashes = cache("hashes") ondisk_mtime = os.path.getmtime(filename) - if title in hashes: - cached_sha256 = 
hashes[title].get("sha256", None) - cached_mtime = hashes[title].get("mtime", 0) + if title not in hashes: + return None - if ondisk_mtime <= cached_mtime and cached_sha256 is not None: - return cached_sha256 + cached_sha256 = hashes[title].get("sha256", None) + cached_mtime = hashes[title].get("mtime", 0) + + if ondisk_mtime > cached_mtime or cached_sha256 is None: + return None + + return cached_sha256 + + +def sha256(filename, title): + hashes = cache("hashes") + + sha256_value = sha256_from_cache(filename, title) + if sha256_value is not None: + return sha256_value print(f"Calculating sha256 for {filename}: ", end='') sha256_value = calculate_sha256(filename) print(f"{sha256_value}") hashes[title] = { - "mtime": ondisk_mtime, + "mtime": os.path.getmtime(filename), "sha256": sha256_value, } diff --git a/modules/sd_models.py b/modules/sd_models.py index 1fe6d11b..e5a0bc63 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -44,9 +44,11 @@ class CheckpointInfo: self.title = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] - self.shorthash = None - self.sha256 = None + + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.shorthash = self.sha256[0:10] if self.sha256 else None + + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -269,6 +271,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 model.logvar = model.logvar.to(devices.device) # fix for training diff --git a/modules/shared.py b/modules/shared.py index a6c61db3..c9988d4d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -458,6 +458,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable those extensions"), + "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) options_templates.update() diff --git a/modules/ui.py b/modules/ui.py index e86a624b..2625ae32 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1841,4 +1841,6 @@ xformers: {xformers_version} gradio: {gr.__version__}  •  commit: {short_commit} + •  +checkpoint: N/A """ diff --git a/script.js b/script.js index 21960d91..3345e32b 100644 --- a/script.js +++ b/script.js @@ -14,6 +14,7 @@ function get_uiCurrentTabContent() { uiUpdateCallbacks = [] uiTabChangeCallbacks = [] +optionsChangedCallbacks = [] let uiCurrentTab = null function onUiUpdate(callback){ @@ -22,6 +23,9 @@ function onUiUpdate(callback){ function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } +function onOptionsChanged(callback){ + optionsChangedCallbacks.push(callback) +} function runCallback(x, m){ try { From f94a215abed85b34ae978853078812801d3e7738 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 16:29:23 +0300 Subject: [PATCH 35/38] add an option to choose what you want to see in live preview (Live preview subject) and moves live preview settings to its own tab --- modules/sd_samplers.py | 15 ++++++++++----- 
modules/shared.py | 13 +++++++++---- modules/ui_progress.py | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 01221b89..7616fded 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None): def store_latent(decoded): state.current_latent = decoded - if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: + if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: shared.state.current_image = sample_to_image(decoded) @@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler: self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) if valid_step == floor(valid_step): return int(valid_step) + 1 @@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler: if image_conditioning is not None: conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - + samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) return samples @@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module): x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + if opts.live_preview_content == "Prompt": + store_latent(x_out[0:uncond.shape[0]]) + elif opts.live_preview_content == "Negative prompt": + store_latent(x_out[-uncond.shape[0]:]) + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) if self.mask is not None: @@ -423,7 +427,8 @@ class KDiffusionSampler: def callback_state(self, d): step = d['i'] latent = d["denoised"] - store_latent(latent) + if opts.live_preview_content == "Combined": + store_latent(latent) self.last_latent = latent if self.stop_at is not None and step > self.stop_at: diff --git a/modules/shared.py b/modules/shared.py index c9988d4d..e0ec3136 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -176,7 +176,7 @@ class State: self.interrupted = True def nextjob(self): - if opts.show_progress_every_n_steps == -1: + if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: self.do_set_current_image() self.job_no += 1 @@ -224,7 +224,7 @@ class State: if not parallel_processing_allowed: return - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable: self.do_set_current_image() def do_set_current_image(self): @@ -423,8 +423,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": 
OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), @@ -444,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), { 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) +options_templates.update(options_section(('ui', "Live previews"), { + "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), + "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), + "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), +})) + options_templates.update(options_section(('sampler-params', "Sampler parameters"), { "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), diff --git a/modules/ui_progress.py b/modules/ui_progress.py index 592fda55..7cd312e4 100644 --- a/modules/ui_progress.py +++ b/modules/ui_progress.py @@ -52,7 +52,7 @@ def check_progress_call(id_part): image = gr.update(visible=False) preview_visibility = gr.update(visible=False) - if opts.show_progress_every_n_steps != 0: + if opts.live_previews_enable: shared.state.set_current_image() image = shared.state.current_image From 69781031e7473e020b3af4461fdceb20130e56ab Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 16:45:39 +0300 Subject: [PATCH 36/38] simplify expression in prompts from file script --- scripts/prompts_from_file.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 1fe10a7c..f3e711d7 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -146,11 +146,7 @@ class Script(scripts.Script): else: args = {"prompt": line} - n_iter = args.get("n_iter", p.n_iter) - if n_iter != 1: - job_count += n_iter - else: - job_count += 1 + job_count += args.get("n_iter", p.n_iter) jobs.append(args) From a5bbcd215304e0c83ab2b9fe7f172f88536d7629 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 19:56:09 +0300 Subject: [PATCH 37/38] fix bug with "Ignore selected VAE for..." 
option completely disabling VAE election rework VAE resolving code to be more simple --- modules/sd_models.py | 6 +- modules/sd_vae.py | 174 +++++++++++++++++-------------------------- modules/shared.py | 4 +- scripts/xy_grid.py | 27 ++++--- 4 files changed, 85 insertions(+), 126 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index e5a0bc63..6a681cef 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): +def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) - sd_vae.load_vae(model, vae_file) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) def enable_midas_autodownload(): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0a49daa1..6ea92711 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -9,23 +9,9 @@ import glob from copy import deepcopy -model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) -vae_dir = "VAE" -vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) - - +vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - - -default_vae_dict = {"auto": "auto", "None": None, None: None} -default_vae_list = ["auto", "None"] - - -default_vae_values = [default_vae_dict[x] for x in default_vae_list] -vae_dict = dict(default_vae_dict) -vae_list = list(default_vae_list) -first_load = True +vae_dict = {} base_vae = None @@ -64,100 +50,69 @@ def restore_base_vae(model): def get_filename(filepath): - return os.path.splitext(os.path.basename(filepath))[0] + return os.path.basename(filepath) -def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list - res = {} - candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), +def refresh_vae_list(): + vae_dict.clear() + + paths = [ + os.path.join(sd_models.model_path, '**/*.vae.ckpt'), + os.path.join(sd_models.model_path, '**/*.vae.pt'), + os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(vae_path, '**/*.ckpt'), + os.path.join(vae_path, '**/*.pt'), + os.path.join(vae_path, '**/*.safetensors'), ] - if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): - candidates.append(shared.cmd_opts.vae_path) + + if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): + paths += [ + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + ] + + candidates = [] + for path in paths: + 
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0a49daa1..6ea92711 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -9,23 +9,9 @@ import glob
 from copy import deepcopy

-model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(models_path, model_dir))
-vae_dir = "VAE"
-vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
-
-
+vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
 vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-
-default_vae_dict = {"auto": "auto", "None": None, None: None}
-default_vae_list = ["auto", "None"]
-
-
-default_vae_values = [default_vae_dict[x] for x in default_vae_list]
-vae_dict = dict(default_vae_dict)
-vae_list = list(default_vae_list)
-first_load = True
+vae_dict = {}

 base_vae = None
@@ -64,100 +50,69 @@ def restore_base_vae(model):


 def get_filename(filepath):
-    return os.path.splitext(os.path.basename(filepath))[0]
+    return os.path.basename(filepath)


-def refresh_vae_list(vae_path=vae_path, model_path=model_path):
-    global vae_dict, vae_list
-    res = {}
-    candidates = [
-        *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
-        *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
-        *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
-        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
-        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
-        *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
+def refresh_vae_list():
+    vae_dict.clear()
+
+    paths = [
+        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
+        os.path.join(sd_models.model_path, '**/*.vae.pt'),
+        os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
+        os.path.join(vae_path, '**/*.ckpt'),
+        os.path.join(vae_path, '**/*.pt'),
+        os.path.join(vae_path, '**/*.safetensors'),
     ]
-    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
-        candidates.append(shared.cmd_opts.vae_path)
+
+    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
+        paths += [
+            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
+            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
+            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
+        ]
+
+    candidates = []
+    for path in paths:
+        candidates += glob.iglob(path, recursive=True)
+
     for filepath in candidates:
         name = get_filename(filepath)
-        res[name] = filepath
-    vae_list.clear()
-    vae_list.extend(default_vae_list)
-    vae_list.extend(list(res.keys()))
-    vae_dict.clear()
-    vae_dict.update(res)
-    vae_dict.update(default_vae_dict)
-    return vae_list
+        vae_dict[name] = filepath


-def get_vae_from_settings(vae_file="auto"):
-    # else, we load from settings, if not set to be default
-    if vae_file == "auto" and shared.opts.sd_vae is not None:
-        # if saved VAE settings isn't recognized, fallback to auto
-        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
-        # if VAE selected but not found, fallback to auto
-        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
-            vae_file = "auto"
-            print(f"Selected VAE doesn't exist: {vae_file}")
-    return vae_file
+def find_vae_near_checkpoint(checkpoint_file):
+    checkpoint_path = os.path.splitext(checkpoint_file)[0]
+    for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
+        if os.path.isfile(vae_location):
+            return vae_location
+
+    return None


-def resolve_vae(checkpoint_file=None, vae_file="auto"):
-    global first_load, vae_dict, vae_list
+def resolve_vae(checkpoint_file):
+    if shared.cmd_opts.vae_path is not None:
+        return shared.cmd_opts.vae_path, 'from commandline argument'

-    # if vae_file argument is provided, it takes priority, but not saved
-    if vae_file and vae_file not in default_vae_list:
-        if not os.path.isfile(vae_file):
-            print(f"VAE provided as function argument doesn't exist: {vae_file}")
-            vae_file = "auto"
-    # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
-    if first_load and shared.cmd_opts.vae_path is not None:
-        if os.path.isfile(shared.cmd_opts.vae_path):
-            vae_file = shared.cmd_opts.vae_path
-            shared.opts.data['sd_vae'] = get_filename(vae_file)
-        else:
-            print(f"VAE provided as command line argument doesn't exist: {vae_file}")
-    # fallback to selector in settings, if vae selector not set to act as default fallback
-    if not shared.opts.sd_vae_as_default:
-        vae_file = get_vae_from_settings(vae_file)
-    # vae-path cmd arg takes priority for auto
-    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
-        if os.path.isfile(shared.cmd_opts.vae_path):
-            vae_file = shared.cmd_opts.vae_path
-            print(f"Using VAE provided as command line argument: {vae_file}")
-    # if still not found, try look for ".vae.pt" beside model
-    model_path = os.path.splitext(checkpoint_file)[0]
-    if vae_file == "auto":
-        vae_file_try = model_path + ".vae.pt"
-        if os.path.isfile(vae_file_try):
-            vae_file = vae_file_try
-            print(f"Using VAE found similar to selected model: {vae_file}")
-    # if still not found, try look for ".vae.ckpt" beside model
-    if vae_file == "auto":
-        vae_file_try = model_path + ".vae.ckpt"
-        if os.path.isfile(vae_file_try):
-            vae_file = vae_file_try
-            print(f"Using VAE found similar to selected model: {vae_file}")
-    # if still not found, try look for ".vae.safetensors" beside model
-    if vae_file == "auto":
-        vae_file_try = model_path + ".vae.safetensors"
-        if os.path.isfile(vae_file_try):
-            vae_file = vae_file_try
-            print(f"Using VAE found similar to selected model: {vae_file}")
-    # No more fallbacks for auto
-    if vae_file == "auto":
-        vae_file = None
-    # Last check, just because
-    if vae_file and not os.path.exists(vae_file):
-        vae_file = None
+    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
+    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"):
+        return vae_near_checkpoint, 'found near the checkpoint'

-    return vae_file
+    if shared.opts.sd_vae == "None":
+        return None, None
+
+    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+    if vae_from_options is not None:
+        return vae_from_options, 'specified in settings'
+
+    if shared.opts.sd_vae != "Automatic":
+        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+    return None, None


-def load_vae(model, vae_file=None):
-    global first_load, vae_dict, vae_list, loaded_vae_file
+def load_vae(model, vae_file=None, vae_source="from unknown source"):
+    global vae_dict, loaded_vae_file
     # save_settings = False

     cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
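
Taken together, the rewritten resolve_vae() above establishes a fixed precedence: the --vae-path commandline argument wins outright, then a VAE file sitting next to the checkpoint, then the sd_vae setting, and finally no VAE at all. A walkthrough under a hypothetical layout (file names and option values below are invented; note that the sd_vae == "auto" comparison in this hunk is the typo the final patch in this series corrects to "Automatic"):

    # Hypothetical files on disk:
    #   models/Stable-diffusion/mymodel.ckpt
    #   models/Stable-diffusion/mymodel.vae.pt
    #   models/VAE/anime.vae.pt
    #
    # --vae-path /tmp/other.vae.pt set:
    #   resolve_vae(".../mymodel.ckpt") -> ("/tmp/other.vae.pt", 'from commandline argument')
    #
    # --vae-path unset, sd_vae == "Automatic" (behavior after the final typo fix):
    #   resolve_vae(".../mymodel.ckpt") -> (".../mymodel.vae.pt", 'found near the checkpoint')
    #
    # --vae-path unset, sd_vae == "anime.vae.pt", sd_vae_as_default off:
    #   resolve_vae(".../mymodel.ckpt") -> ("models/VAE/anime.vae.pt", 'specified in settings')
    #
    # --vae-path unset, sd_vae == "None":
    #   resolve_vae(".../mymodel.ckpt") -> (None, None)
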
@@ -165,12 +120,12 @@
     if vae_file:
         if cache_enabled and vae_file in checkpoints_loaded:
             # use vae checkpoint cache
-            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
             store_base_vae(model)
             _load_vae_dict(model, checkpoints_loaded[vae_file])
         else:
-            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
-            print(f"Loading VAE weights from: {vae_file}")
+            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
+            print(f"Loading VAE weights {vae_source}: {vae_file}")
             store_base_vae(model)
             vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
@@ -191,14 +146,12 @@
         vae_opt = get_filename(vae_file)
         if vae_opt not in vae_dict:
             vae_dict[vae_opt] = vae_file
-            vae_list.append(vae_opt)
+
     elif loaded_vae_file:
         restore_base_vae(model)

     loaded_vae_file = vae_file
-    first_load = False
-

 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
@@ -211,7 +164,10 @@ def clear_loaded_vae():
     loaded_vae_file = None


-def reload_vae_weights(sd_model=None, vae_file="auto"):
+unspecified = object()
+
+
+def reload_vae_weights(sd_model=None, vae_file=unspecified):
     from modules import lowvram, devices, sd_hijack

     if not sd_model:
@@ -219,7 +175,11 @@
     checkpoint_info = sd_model.sd_checkpoint_info
     checkpoint_file = checkpoint_info.filename

-    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    if vae_file == unspecified:
+        vae_file, vae_source = resolve_vae(checkpoint_file)
+    else:
+        vae_source = "from function argument"

     if loaded_vae_file == vae_file:
         return
@@ -231,7 +191,7 @@

     sd_hijack.model_hijack.undo_hijack(sd_model)

-    load_vae(sd_model, vae_file)
+    load_vae(sd_model, vae_file, vae_source)

     sd_hijack.model_hijack.hijack(sd_model)
     script_callbacks.model_loaded_callback(sd_model)
@@ -239,5 +199,5 @@
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)

-    print("VAE Weights loaded.")
+    print("VAE weights loaded.")
     return sd_model
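
The unspecified = object() sentinel introduced above is needed because None is already a meaningful value for vae_file (an explicit "no VAE"), so a distinct default is required to tell "argument omitted" apart from "caller asked for no VAE". The pattern in isolation, with illustrative names:

    unspecified = object()  # unique marker; can never equal a real path or None

    def reload(vae_file=unspecified):
        if vae_file is unspecified:      # argument omitted: resolve automatically
            vae_file = "resolved.vae.pt"  # stand-in for resolve_vae()
        elif vae_file is None:           # explicit None: load no VAE at all
            pass
        return vae_file

The patch compares with == rather than is; for a bare object() the result is the same, since equality falls back to identity, though `is` is the more conventional spelling for sentinel checks.
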
diff --git a/modules/shared.py b/modules/shared.py
index e0ec3136..9756adea 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
 parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
 parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
 parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
 parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
 parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
 parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -383,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index f04d9b7e..bd3087d4 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs):


 def find_vae(name: str):
-    if name.lower() in ['auto', 'none']:
-        return name
+    if name.lower() in ['auto', 'automatic']:
+        return modules.sd_vae.unspecified
+    if name.lower() == 'none':
+        return None
     else:
-        vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
-        found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
-        if found:
-            return found[0]
+        choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
+        if len(choices) == 0:
+            print(f"No VAE found for {name}; using automatic")
+            return modules.sd_vae.unspecified
         else:
-            return 'auto'
+            return modules.sd_vae.vae_dict[choices[0]]


 def apply_vae(p, x, xs):
-    if x.lower().strip() == 'none':
-        modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
-    else:
-        found = find_vae(x)
-        if found:
-            v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
+    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))


 def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
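
The rewritten find_vae() above matches the X/Y grid's VAE axis values against vae_dict case-insensitively as substrings, preferring the shortest matching name. Assuming a hypothetical vae_dict, it behaves like this:

    # vae_dict = {"anime.vae.pt": ..., "anime-v2.vae.pt": ..., "mse-840000.ckpt": ...}
    #
    # find_vae("anime")     -> vae_dict["anime.vae.pt"]      (shortest match wins)
    # find_vae("v2")        -> vae_dict["anime-v2.vae.pt"]
    # find_vae("none")      -> None                          (explicitly disable the VAE)
    # find_vae("automatic") -> modules.sd_vae.unspecified    (defer to resolve_vae)
    # find_vae("nomatch")   -> prints "No VAE found for nomatch; using automatic"
    #                          and returns modules.sd_vae.unspecified
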
@@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object):
     def __exit__(self, exc_type, exc_value, tb):
         modules.sd_models.reload_model_weights(self.model)

-        modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
+
+        opts.data["sd_vae"] = self.vae
+        modules.sd_vae.reload_vae_weights(self.model)

         hypernetwork.load_hypernetwork(self.hypernetwork)
         hypernetwork.apply_strength()

From f8c512478568293155539f616dce26c5e4495055 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 20:00:12 +0300
Subject: [PATCH 38/38] typo?

---
 modules/sd_vae.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 6ea92711..add5cecf 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -95,7 +95,7 @@ def resolve_vae(checkpoint_file):
         return shared.cmd_opts.vae_path, 'from commandline argument'

     vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
-    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"):
+    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"):
         return vae_near_checkpoint, 'found near the checkpoint'

     if shared.opts.sd_vae == "None":