Merge 66cad3ab6f
into ea9bd9fc74
This commit is contained in:
commit
3ac55dffab
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks
|
from modules import shared, ui_tempdir, script_callbacks, sd_vae
|
||||||
import tempfile
|
import tempfile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -282,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||||
|
|
||||||
restore_old_hires_fix_params(res)
|
restore_old_hires_fix_params(res)
|
||||||
|
|
||||||
|
if "VAE" in res:
|
||||||
|
vae_name = res["VAE"]
|
||||||
|
vae_hash = res.get("VAE hash", None)
|
||||||
|
res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -295,7 +300,8 @@ infotext_to_setting_name_mapping = [
|
||||||
('Noise multiplier', 'initial_noise_multiplier'),
|
('Noise multiplier', 'initial_noise_multiplier'),
|
||||||
('Eta', 'eta_ancestral'),
|
('Eta', 'eta_ancestral'),
|
||||||
('Eta DDIM', 'eta_ddim'),
|
('Eta DDIM', 'eta_ddim'),
|
||||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
|
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
||||||
|
('VAE', 'sd_vae'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -459,6 +459,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||||
|
"VAE": "None" if sd_vae.loaded_vae_file is None else os.path.split(sd_vae.loaded_vae_file)[1].removesuffix(".pt"),
|
||||||
|
"VAE hash": None if sd_vae.loaded_vae_file is None else sd_models.model_hash(sd_vae.loaded_vae_file),
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
|
|
@ -6,6 +6,7 @@ from collections import namedtuple
|
||||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||||
|
@ -17,6 +18,8 @@ base_vae = None
|
||||||
loaded_vae_file = None
|
loaded_vae_file = None
|
||||||
checkpoint_info = None
|
checkpoint_info = None
|
||||||
|
|
||||||
|
vae_hash_to_filename = defaultdict(list)
|
||||||
|
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
def get_base_vae(model):
|
def get_base_vae(model):
|
||||||
|
@ -82,9 +85,11 @@ def refresh_vae_list():
|
||||||
for path in paths:
|
for path in paths:
|
||||||
candidates += glob.iglob(path, recursive=True)
|
candidates += glob.iglob(path, recursive=True)
|
||||||
|
|
||||||
|
vae_hash_to_filename.clear()
|
||||||
for filepath in candidates:
|
for filepath in candidates:
|
||||||
name = get_filename(filepath)
|
name = get_filename(filepath)
|
||||||
vae_dict[name] = filepath
|
vae_dict[name] = filepath
|
||||||
|
vae_hash_to_filename[sd_models.model_hash(filepath)].append(name)
|
||||||
|
|
||||||
|
|
||||||
def find_vae_near_checkpoint(checkpoint_file):
|
def find_vae_near_checkpoint(checkpoint_file):
|
||||||
|
@ -159,6 +164,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||||
vae_opt = get_filename(vae_file)
|
vae_opt = get_filename(vae_file)
|
||||||
if vae_opt not in vae_dict:
|
if vae_opt not in vae_dict:
|
||||||
vae_dict[vae_opt] = vae_file
|
vae_dict[vae_opt] = vae_file
|
||||||
|
vae_hash_to_filename[sd_models.model_hash(vae_file)] = vae_opt
|
||||||
|
|
||||||
elif loaded_vae_file:
|
elif loaded_vae_file:
|
||||||
restore_base_vae(model)
|
restore_base_vae(model)
|
||||||
|
@ -214,3 +220,28 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||||
|
|
||||||
print("VAE weights loaded.")
|
print("VAE weights loaded.")
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_vae(vae_file: str):
|
||||||
|
"""
|
||||||
|
Returns true if the vae_file name exists in the cache of vae files
|
||||||
|
A vae_file of "None" is valid because it represents the "None" option in the vae_se
|
||||||
|
"""
|
||||||
|
return vae_file in vae_dict or vae_file == "None"
|
||||||
|
|
||||||
|
|
||||||
|
def find_vae_key(vae_name, vae_hash=None):
|
||||||
|
"""Determines the config parameter name to use for the VAE based on the parameters in the infotext.
|
||||||
|
If vae_hash is provided, this function will return the name of any local VAE file that matches the hash.
|
||||||
|
If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name
|
||||||
|
"""
|
||||||
|
if vae_name == "None":
|
||||||
|
return vae_name
|
||||||
|
if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)):
|
||||||
|
if vae_name in matched or vae_name.lower() in matched:
|
||||||
|
return vae_name
|
||||||
|
return matched[0]
|
||||||
|
else:
|
||||||
|
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_dict.keys()]:
|
||||||
|
return vae_name
|
||||||
|
return None
|
||||||
|
|
|
@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
from modules.textual_inversion import textual_inversion
|
from modules.textual_inversion import textual_inversion
|
||||||
import modules.hypernetworks.ui
|
import modules.hypernetworks.ui
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
from modules.sd_vae import is_valid_vae
|
||||||
import modules.extras
|
import modules.extras
|
||||||
|
|
||||||
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||||
|
@ -366,6 +367,9 @@ def apply_setting(key, value):
|
||||||
value = ckpt_info.title
|
value = ckpt_info.title
|
||||||
else:
|
else:
|
||||||
return gr.update()
|
return gr.update()
|
||||||
|
if key == 'sd_vae' and not is_valid_vae(value):
|
||||||
|
# ignore invalid vaes
|
||||||
|
return gr.update()
|
||||||
|
|
||||||
comp_args = opts.data_labels[key].component_args
|
comp_args = opts.data_labels[key].component_args
|
||||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user