diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index fc9e17aa..ec3cc52e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -8,7 +8,7 @@ from pathlib import Path import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks +from modules import shared, ui_tempdir, script_callbacks, sd_vae import tempfile from PIL import Image @@ -282,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model restore_old_hires_fix_params(res) + if "VAE" in res: + vae_name = res["VAE"] + vae_hash = res.get("VAE hash", None) + res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash) + return res @@ -295,7 +300,8 @@ infotext_to_setting_name_mapping = [ ('Noise multiplier', 'initial_noise_multiplier'), ('Eta', 'eta_ancestral'), ('Eta DDIM', 'eta_ddim'), - ('Discard penultimate sigma', 'always_discard_next_to_last_sigma') + ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'), + ('VAE', 'sd_vae'), ] diff --git a/modules/processing.py b/modules/processing.py index e1b53ac0..9f25c691 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -459,6 +459,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, + "VAE": "None" if sd_vae.loaded_vae_file is None else os.path.split(sd_vae.loaded_vae_file)[1].removesuffix(".pt"), + "VAE hash": None if sd_vae.loaded_vae_file is None else sd_models.model_hash(sd_vae.loaded_vae_file), } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9b00f76e..b69ac795 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -6,6 +6,7 @@ from collections import namedtuple from modules import paths, shared, devices, script_callbacks, sd_models import glob from copy import deepcopy +from collections import defaultdict vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) @@ -17,6 +18,8 @@ base_vae = None loaded_vae_file = None checkpoint_info = None +vae_hash_to_filename = defaultdict(list) + checkpoints_loaded = collections.OrderedDict() def get_base_vae(model): @@ -82,9 +85,11 @@ def refresh_vae_list(): for path in paths: candidates += glob.iglob(path, recursive=True) + vae_hash_to_filename.clear() for filepath in candidates: name = get_filename(filepath) vae_dict[name] = filepath + vae_hash_to_filename[sd_models.model_hash(filepath)].append(name) def find_vae_near_checkpoint(checkpoint_file): @@ -159,6 +164,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file + vae_hash_to_filename[sd_models.model_hash(vae_file)] = vae_opt elif loaded_vae_file: restore_base_vae(model) @@ -214,3 +220,28 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): print("VAE weights loaded.") return sd_model + + +def is_valid_vae(vae_file: str): + """ + Returns true if the vae_file name exists in the cache of vae files + A vae_file of "None" is valid because it represents the "None" option in the vae_se + """ + return vae_file in vae_dict or vae_file == "None" + + +def find_vae_key(vae_name, vae_hash=None): + """Determines the config parameter name to use for the VAE based on the parameters in the infotext. + If vae_hash is provided, this function will return the name of any local VAE file that matches the hash. + If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name + """ + if vae_name == "None": + return vae_name + if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)): + if vae_name in matched or vae_name.lower() in matched: + return vae_name + return matched[0] + else: + if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_dict.keys()]: + return vae_name + return None diff --git a/modules/ui.py b/modules/ui.py index f5df1ffe..2d1e4d79 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +from modules.sd_vae import is_valid_vae import modules.extras warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) @@ -366,6 +367,9 @@ def apply_setting(key, value): value = ckpt_info.title else: return gr.update() + if key == 'sd_vae' and not is_valid_vae(value): + # ignore invalid vaes + return gr.update() comp_args = opts.data_labels[key].component_args if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: