added VAE and VAE hash to image generation params; 'Send to' buttons in PNG Info now load VAE from image's generation params
This commit is contained in:
parent
8850fc23b6
commit
3682ae2e66
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules import shared, ui_tempdir
|
||||
from modules import shared, ui_tempdir, sd_vae
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
|
@ -83,6 +83,7 @@ def integrate_settings_paste_fields(component_dict):
|
|||
'sd_model_checkpoint': 'Model hash',
|
||||
'eta_noise_seed_delta': 'ENSD',
|
||||
'initial_noise_multiplier': 'Noise multiplier',
|
||||
'sd_vae': 'VAE',
|
||||
}
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||
|
@ -281,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
if "VAE" in res:
|
||||
vae_name = res["VAE"]
|
||||
vae_hash = res.get("VAE hash", None)
|
||||
res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
@ -449,6 +449,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
"VAE": "None" if shared.loaded_vae_file is None else os.path.split(shared.loaded_vae_file)[1].removesuffix(".pt"),
|
||||
"VAE hash": None if shared.loaded_vae_file is None else sd_models.model_hash(shared.loaded_vae_file),
|
||||
}
|
||||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
|
|
@ -2,10 +2,11 @@ import torch
|
|||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import shared, devices, script_callbacks
|
||||
from modules import shared, devices, script_callbacks, sd_models
|
||||
from modules.paths import models_path
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
|
@ -24,11 +25,11 @@ default_vae_list = ["auto", "None"]
|
|||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
||||
vae_dict = dict(default_vae_dict)
|
||||
vae_list = list(default_vae_list)
|
||||
vae_hash_to_filename = defaultdict(list)
|
||||
first_load = True
|
||||
|
||||
|
||||
base_vae = None
|
||||
loaded_vae_file = None
|
||||
checkpoint_info = None
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
@ -42,7 +43,7 @@ def get_base_vae(model):
|
|||
def store_base_vae(model):
|
||||
global base_vae, checkpoint_info
|
||||
if checkpoint_info != model.sd_checkpoint_info:
|
||||
assert not loaded_vae_file, "Trying to store non-base VAE!"
|
||||
assert not shared.loaded_vae_file, "Trying to store non-base VAE!"
|
||||
base_vae = deepcopy(model.first_stage_model.state_dict())
|
||||
checkpoint_info = model.sd_checkpoint_info
|
||||
|
||||
|
@ -54,11 +55,11 @@ def delete_base_vae():
|
|||
|
||||
|
||||
def restore_base_vae(model):
|
||||
global loaded_vae_file
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||
print("Restoring base VAE")
|
||||
_load_vae_dict(model, base_vae)
|
||||
loaded_vae_file = None
|
||||
shared.loaded_vae_file = None
|
||||
shared.opts.sd_vae = "None"
|
||||
delete_base_vae()
|
||||
|
||||
|
||||
|
@ -67,7 +68,7 @@ def get_filename(filepath):
|
|||
|
||||
|
||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||
global vae_dict, vae_list
|
||||
global vae_dict, vae_list, vae_hash_to_filename
|
||||
res = {}
|
||||
candidates = [
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
||||
|
@ -77,9 +78,11 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
|||
]
|
||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
||||
candidates.append(shared.cmd_opts.vae_path)
|
||||
vae_hash_to_filename.clear()
|
||||
for filepath in candidates:
|
||||
name = get_filename(filepath)
|
||||
res[name] = filepath
|
||||
vae_hash_to_filename[sd_models.model_hash(filepath)].append(name)
|
||||
vae_list.clear()
|
||||
vae_list.extend(default_vae_list)
|
||||
vae_list.extend(list(res.keys()))
|
||||
|
@ -148,7 +151,7 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
|||
|
||||
|
||||
def load_vae(model, vae_file=None):
|
||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||
global first_load, vae_dict, vae_list
|
||||
# save_settings = False
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
@ -182,11 +185,10 @@ def load_vae(model, vae_file=None):
|
|||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
vae_list.append(vae_opt)
|
||||
elif loaded_vae_file:
|
||||
shared.loaded_vae_file = vae_file
|
||||
elif shared.loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
|
||||
loaded_vae_file = vae_file
|
||||
|
||||
first_load = False
|
||||
|
||||
|
||||
|
@ -196,8 +198,7 @@ def _load_vae_dict(model, vae_dict_1):
|
|||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
def clear_loaded_vae():
|
||||
global loaded_vae_file
|
||||
loaded_vae_file = None
|
||||
shared.loaded_vae_file = None
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
|
@ -209,7 +210,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
checkpoint_file = checkpoint_info.filename
|
||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
||||
if loaded_vae_file == vae_file:
|
||||
if shared.loaded_vae_file == vae_file:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
|
@ -229,3 +230,22 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
|
||||
print("VAE Weights loaded.")
|
||||
return sd_model
|
||||
|
||||
|
||||
def is_valid_vae(vae_file: str):
|
||||
return vae_file in vae_dict
|
||||
|
||||
|
||||
def find_vae_key(vae_name, vae_hash=None):
|
||||
"""Determines the config parameter name to use for the VAE based on the parameters in the infotext.
|
||||
If vae_hash is provided, this function will return the name of any local VAE file that matches the hash.
|
||||
If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name
|
||||
"""
|
||||
if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)):
|
||||
if vae_name in matched or vae_name.lower() in matched:
|
||||
return vae_name
|
||||
return matched[0]
|
||||
else:
|
||||
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_list]:
|
||||
return vae_name
|
||||
return None
|
||||
|
|
|
@ -142,6 +142,8 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
|||
hypernetworks = {}
|
||||
loaded_hypernetwork = None
|
||||
|
||||
loaded_vae_file = None
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
|
|
@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
|
|||
import modules.textual_inversion.ui
|
||||
import modules.hypernetworks.ui
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
from modules.sd_vae import is_valid_vae
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
|
@ -495,6 +496,9 @@ def apply_setting(key, value):
|
|||
value = ckpt_info.title
|
||||
else:
|
||||
return gr.update()
|
||||
if key == 'sd_vae' and not is_valid_vae(value):
|
||||
# ignore invalid vaes
|
||||
return gr.update()
|
||||
|
||||
comp_args = opts.data_labels[key].component_args
|
||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||
|
|
Loading…
Reference in New Issue
Block a user