Merge 66cad3ab6f
into ea9bd9fc74
This commit is contained in:
commit
3ac55dffab
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||
|
||||
import gradio as gr
|
||||
from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks
|
||||
from modules import shared, ui_tempdir, script_callbacks, sd_vae
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
|
@ -282,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
if "VAE" in res:
|
||||
vae_name = res["VAE"]
|
||||
vae_hash = res.get("VAE hash", None)
|
||||
res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
@ -295,7 +300,8 @@ infotext_to_setting_name_mapping = [
|
|||
('Noise multiplier', 'initial_noise_multiplier'),
|
||||
('Eta', 'eta_ancestral'),
|
||||
('Eta DDIM', 'eta_ddim'),
|
||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
|
||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
||||
('VAE', 'sd_vae'),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -459,6 +459,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
"VAE": "None" if sd_vae.loaded_vae_file is None else os.path.split(sd_vae.loaded_vae_file)[1].removesuffix(".pt"),
|
||||
"VAE hash": None if sd_vae.loaded_vae_file is None else sd_models.model_hash(sd_vae.loaded_vae_file),
|
||||
}
|
||||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
|
|
@ -6,6 +6,7 @@ from collections import namedtuple
|
|||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||
|
@ -17,6 +18,8 @@ base_vae = None
|
|||
loaded_vae_file = None
|
||||
checkpoint_info = None
|
||||
|
||||
vae_hash_to_filename = defaultdict(list)
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
def get_base_vae(model):
|
||||
|
@ -82,9 +85,11 @@ def refresh_vae_list():
|
|||
for path in paths:
|
||||
candidates += glob.iglob(path, recursive=True)
|
||||
|
||||
vae_hash_to_filename.clear()
|
||||
for filepath in candidates:
|
||||
name = get_filename(filepath)
|
||||
vae_dict[name] = filepath
|
||||
vae_hash_to_filename[sd_models.model_hash(filepath)].append(name)
|
||||
|
||||
|
||||
def find_vae_near_checkpoint(checkpoint_file):
|
||||
|
@ -159,6 +164,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
|||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
vae_hash_to_filename[sd_models.model_hash(vae_file)] = vae_opt
|
||||
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
|
@ -214,3 +220,28 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||
|
||||
print("VAE weights loaded.")
|
||||
return sd_model
|
||||
|
||||
|
||||
def is_valid_vae(vae_file: str):
|
||||
"""
|
||||
Returns true if the vae_file name exists in the cache of vae files
|
||||
A vae_file of "None" is valid because it represents the "None" option in the vae_se
|
||||
"""
|
||||
return vae_file in vae_dict or vae_file == "None"
|
||||
|
||||
|
||||
def find_vae_key(vae_name, vae_hash=None):
|
||||
"""Determines the config parameter name to use for the VAE based on the parameters in the infotext.
|
||||
If vae_hash is provided, this function will return the name of any local VAE file that matches the hash.
|
||||
If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name
|
||||
"""
|
||||
if vae_name == "None":
|
||||
return vae_name
|
||||
if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)):
|
||||
if vae_name in matched or vae_name.lower() in matched:
|
||||
return vae_name
|
||||
return matched[0]
|
||||
else:
|
||||
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_dict.keys()]:
|
||||
return vae_name
|
||||
return None
|
||||
|
|
|
@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
|
|||
from modules.textual_inversion import textual_inversion
|
||||
import modules.hypernetworks.ui
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
from modules.sd_vae import is_valid_vae
|
||||
import modules.extras
|
||||
|
||||
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||
|
@ -366,6 +367,9 @@ def apply_setting(key, value):
|
|||
value = ckpt_info.title
|
||||
else:
|
||||
return gr.update()
|
||||
if key == 'sd_vae' and not is_valid_vae(value):
|
||||
# ignore invalid vaes
|
||||
return gr.update()
|
||||
|
||||
comp_args = opts.data_labels[key].component_args
|
||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||
|
|
Loading…
Reference in New Issue
Block a user