added VAE and VAE hash to image generation params; 'Send to' buttons in PNG Info now load VAE from image's generation params

This commit is contained in:
Liam 2022-12-28 09:03:18 -05:00
parent 8850fc23b6
commit 3682ae2e66
5 changed files with 48 additions and 14 deletions

View File

@ -7,7 +7,7 @@ from pathlib import Path
import gradio as gr
from modules.shared import script_path
from modules import shared, ui_tempdir
from modules import shared, ui_tempdir, sd_vae
import tempfile
from PIL import Image
@ -83,6 +83,7 @@ def integrate_settings_paste_fields(component_dict):
'sd_model_checkpoint': 'Model hash',
'eta_noise_seed_delta': 'ENSD',
'initial_noise_multiplier': 'Noise multiplier',
'sd_vae': 'VAE',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
@ -281,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res)
if "VAE" in res:
vae_name = res["VAE"]
vae_hash = res.get("VAE hash", None)
res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash)
return res

View File

@ -449,6 +449,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"VAE": "None" if shared.loaded_vae_file is None else os.path.split(shared.loaded_vae_file)[1].removesuffix(".pt"),
"VAE hash": None if shared.loaded_vae_file is None else sd_models.model_hash(shared.loaded_vae_file),
}
generation_params.update(p.extra_generation_params)

View File

@ -2,10 +2,11 @@ import torch
import os
import collections
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules import shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob
from copy import deepcopy
from collections import defaultdict
model_dir = "Stable-diffusion"
@ -24,11 +25,11 @@ default_vae_list = ["auto", "None"]
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
vae_hash_to_filename = defaultdict(list)
first_load = True
base_vae = None
loaded_vae_file = None
checkpoint_info = None
checkpoints_loaded = collections.OrderedDict()
@ -42,7 +43,7 @@ def get_base_vae(model):
def store_base_vae(model):
global base_vae, checkpoint_info
if checkpoint_info != model.sd_checkpoint_info:
assert not loaded_vae_file, "Trying to store non-base VAE!"
assert not shared.loaded_vae_file, "Trying to store non-base VAE!"
base_vae = deepcopy(model.first_stage_model.state_dict())
checkpoint_info = model.sd_checkpoint_info
@ -54,11 +55,11 @@ def delete_base_vae():
def restore_base_vae(model):
global loaded_vae_file
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
print("Restoring base VAE")
_load_vae_dict(model, base_vae)
loaded_vae_file = None
shared.loaded_vae_file = None
shared.opts.sd_vae = "None"
delete_base_vae()
@ -67,7 +68,7 @@ def get_filename(filepath):
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
global vae_dict, vae_list
global vae_dict, vae_list, vae_hash_to_filename
res = {}
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
@ -77,9 +78,11 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
vae_hash_to_filename.clear()
for filepath in candidates:
name = get_filename(filepath)
res[name] = filepath
vae_hash_to_filename[sd_models.model_hash(filepath)].append(name)
vae_list.clear()
vae_list.extend(default_vae_list)
vae_list.extend(list(res.keys()))
@ -148,7 +151,7 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
global first_load, vae_dict, vae_list
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@ -182,11 +185,10 @@ def load_vae(model, vae_file=None):
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt)
elif loaded_vae_file:
shared.loaded_vae_file = vae_file
elif shared.loaded_vae_file:
restore_base_vae(model)
loaded_vae_file = vae_file
first_load = False
@ -196,8 +198,7 @@ def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.to(devices.dtype_vae)
def clear_loaded_vae():
global loaded_vae_file
loaded_vae_file = None
shared.loaded_vae_file = None
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
@ -209,7 +210,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
if loaded_vae_file == vae_file:
if shared.loaded_vae_file == vae_file:
return
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@ -229,3 +230,22 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
print("VAE Weights loaded.")
return sd_model
def is_valid_vae(vae_file: str):
return vae_file in vae_dict
def find_vae_key(vae_name, vae_hash=None):
"""Determines the config parameter name to use for the VAE based on the parameters in the infotext.
If vae_hash is provided, this function will return the name of any local VAE file that matches the hash.
If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name
"""
if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)):
if vae_name in matched or vae_name.lower() in matched:
return vae_name
return matched[0]
else:
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_list]:
return vae_name
return None

View File

@ -142,6 +142,8 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
loaded_hypernetwork = None
loaded_vae_file = None
def reload_hypernetworks():
from modules.hypernetworks import hypernetwork

View File

@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
from modules.sd_vae import is_valid_vae
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@ -495,6 +496,9 @@ def apply_setting(key, value):
value = ckpt_info.title
else:
return gr.update()
if key == 'sd_vae' and not is_valid_vae(value):
# ignore invalid vaes
return gr.update()
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: