From 3682ae2e6624130f9e9bba09cb49febab2f94624 Mon Sep 17 00:00:00 2001 From: Liam Date: Wed, 28 Dec 2022 09:03:18 -0500 Subject: [PATCH] added VAE and VAE hash to image generation params; 'Send to' buttons in PNG Info now load VAE from image's generation params --- modules/generation_parameters_copypaste.py | 8 +++- modules/processing.py | 2 + modules/sd_vae.py | 46 ++++++++++++++++------ modules/shared.py | 2 + modules/ui.py | 4 ++ 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 12a9de3d..b69f19e1 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared, ui_tempdir +from modules import shared, ui_tempdir, sd_vae import tempfile from PIL import Image @@ -83,6 +83,7 @@ def integrate_settings_paste_fields(component_dict): 'sd_model_checkpoint': 'Model hash', 'eta_noise_seed_delta': 'ENSD', 'initial_noise_multiplier': 'Noise multiplier', + 'sd_vae': 'VAE', } settings_paste_fields = [ (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None))) @@ -281,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model restore_old_hires_fix_params(res) + if "VAE" in res: + vae_name = res["VAE"] + vae_hash = res.get("VAE hash", None) + res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash) + return res diff --git a/modules/processing.py b/modules/processing.py index 1d23b15f..04e8de16 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -449,6 +449,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, + "VAE": "None" if shared.loaded_vae_file is None else os.path.split(shared.loaded_vae_file)[1].removesuffix(".pt"), + "VAE hash": None if shared.loaded_vae_file is None else sd_models.model_hash(shared.loaded_vae_file), } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index ac71d62d..a680122d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -2,10 +2,11 @@ import torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks +from modules import shared, devices, script_callbacks, sd_models from modules.paths import models_path import glob from copy import deepcopy +from collections import defaultdict model_dir = "Stable-diffusion" @@ -24,11 +25,11 @@ default_vae_list = ["auto", "None"] default_vae_values = [default_vae_dict[x] for x in default_vae_list] vae_dict = dict(default_vae_dict) vae_list = list(default_vae_list) +vae_hash_to_filename = defaultdict(list) first_load = True base_vae = None -loaded_vae_file = None checkpoint_info = None checkpoints_loaded = collections.OrderedDict() @@ -42,7 +43,7 @@ def get_base_vae(model): def store_base_vae(model): global base_vae, checkpoint_info if checkpoint_info != model.sd_checkpoint_info: - assert not loaded_vae_file, "Trying to store non-base VAE!" + assert not shared.loaded_vae_file, "Trying to store non-base VAE!" base_vae = deepcopy(model.first_stage_model.state_dict()) checkpoint_info = model.sd_checkpoint_info @@ -54,11 +55,11 @@ def delete_base_vae(): def restore_base_vae(model): - global loaded_vae_file if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: print("Restoring base VAE") _load_vae_dict(model, base_vae) - loaded_vae_file = None + shared.loaded_vae_file = None + shared.opts.sd_vae = "None" delete_base_vae() @@ -67,7 +68,7 @@ def get_filename(filepath): def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list + global vae_dict, vae_list, vae_hash_to_filename res = {} candidates = [ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), @@ -77,9 +78,11 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): candidates.append(shared.cmd_opts.vae_path) + vae_hash_to_filename.clear() for filepath in candidates: name = get_filename(filepath) res[name] = filepath + vae_hash_to_filename[sd_models.model_hash(filepath)].append(name) vae_list.clear() vae_list.extend(default_vae_list) vae_list.extend(list(res.keys())) @@ -148,7 +151,7 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"): def load_vae(model, vae_file=None): - global first_load, vae_dict, vae_list, loaded_vae_file + global first_load, vae_dict, vae_list # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -182,11 +185,10 @@ def load_vae(model, vae_file=None): if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) - elif loaded_vae_file: + shared.loaded_vae_file = vae_file + elif shared.loaded_vae_file: restore_base_vae(model) - loaded_vae_file = vae_file - first_load = False @@ -196,8 +198,7 @@ def _load_vae_dict(model, vae_dict_1): model.first_stage_model.to(devices.dtype_vae) def clear_loaded_vae(): - global loaded_vae_file - loaded_vae_file = None + shared.loaded_vae_file = None def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack @@ -209,7 +210,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): checkpoint_file = checkpoint_info.filename vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) - if loaded_vae_file == vae_file: + if shared.loaded_vae_file == vae_file: return if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -229,3 +230,22 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): print("VAE Weights loaded.") return sd_model + + +def is_valid_vae(vae_file: str): + return vae_file in vae_dict + + +def find_vae_key(vae_name, vae_hash=None): + """Determines the config parameter name to use for the VAE based on the parameters in the infotext. + If vae_hash is provided, this function will return the name of any local VAE file that matches the hash. + If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name + """ + if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)): + if vae_name in matched or vae_name.lower() in matched: + return vae_name + return matched[0] + else: + if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_list]: + return vae_name + return None diff --git a/modules/shared.py b/modules/shared.py index a6712dae..f4fb7c54 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -142,6 +142,8 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = {} loaded_hypernetwork = None +loaded_vae_file = None + def reload_hypernetworks(): from modules.hypernetworks import hypernetwork diff --git a/modules/ui.py b/modules/ui.py index 99483130..d81307cf 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img import modules.textual_inversion.ui import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +from modules.sd_vae import is_valid_vae # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -495,6 +496,9 @@ def apply_setting(key, value): value = ckpt_info.title else: return gr.update() + if key == 'sd_vae' and not is_valid_vae(value): + # ignore invalid vaes + return gr.update() comp_args = opts.data_labels[key].component_args if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: