This commit is contained in:
liamkerr 2023-02-05 06:51:25 -05:00 committed by GitHub
commit 3ac55dffab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 2 deletions

View File

@ -8,7 +8,7 @@ from pathlib import Path
import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
from modules import shared, ui_tempdir, script_callbacks, sd_vae
import tempfile
from PIL import Image
@ -282,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res)
if "VAE" in res:
vae_name = res["VAE"]
vae_hash = res.get("VAE hash", None)
res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash)
return res
@ -295,7 +300,8 @@ infotext_to_setting_name_mapping = [
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
('VAE', 'sd_vae'),
]

View File

@ -459,6 +459,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"VAE": "None" if sd_vae.loaded_vae_file is None else os.path.split(sd_vae.loaded_vae_file)[1].removesuffix(".pt"),
"VAE hash": None if sd_vae.loaded_vae_file is None else sd_models.model_hash(sd_vae.loaded_vae_file),
}
generation_params.update(p.extra_generation_params)

View File

@ -6,6 +6,7 @@ from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy
from collections import defaultdict
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
@ -17,6 +18,8 @@ base_vae = None
loaded_vae_file = None
checkpoint_info = None
vae_hash_to_filename = defaultdict(list)
checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model):
@ -82,9 +85,11 @@ def refresh_vae_list():
for path in paths:
candidates += glob.iglob(path, recursive=True)
vae_hash_to_filename.clear()
for filepath in candidates:
name = get_filename(filepath)
vae_dict[name] = filepath
vae_hash_to_filename[sd_models.model_hash(filepath)].append(name)
def find_vae_near_checkpoint(checkpoint_file):
@ -159,6 +164,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_hash_to_filename[sd_models.model_hash(vae_file)] = vae_opt
elif loaded_vae_file:
restore_base_vae(model)
@ -214,3 +220,28 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
print("VAE weights loaded.")
return sd_model
def is_valid_vae(vae_file: str):
"""
Returns true if the vae_file name exists in the cache of vae files
A vae_file of "None" is valid because it represents the "None" option in the vae_se
"""
return vae_file in vae_dict or vae_file == "None"
def find_vae_key(vae_name, vae_hash=None):
"""Determines the config parameter name to use for the VAE based on the parameters in the infotext.
If vae_hash is provided, this function will return the name of any local VAE file that matches the hash.
If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name
"""
if vae_name == "None":
return vae_name
if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)):
if vae_name in matched or vae_name.lower() in matched:
return vae_name
return matched[0]
else:
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_dict.keys()]:
return vae_name
return None

View File

@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
from modules.sd_vae import is_valid_vae
import modules.extras
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
@ -366,6 +367,9 @@ def apply_setting(key, value):
value = ckpt_info.title
else:
return gr.update()
if key == 'sd_vae' and not is_valid_vae(value):
# ignore invalid vaes
return gr.update()
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: