added VAE and VAE hash to image generation params; 'Send to' buttons in PNG Info now load VAE from image's generation params
This commit is contained in:
parent
8850fc23b6
commit
3682ae2e66
|
@ -7,7 +7,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.shared import script_path
|
from modules.shared import script_path
|
||||||
from modules import shared, ui_tempdir
|
from modules import shared, ui_tempdir, sd_vae
|
||||||
import tempfile
|
import tempfile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -83,6 +83,7 @@ def integrate_settings_paste_fields(component_dict):
|
||||||
'sd_model_checkpoint': 'Model hash',
|
'sd_model_checkpoint': 'Model hash',
|
||||||
'eta_noise_seed_delta': 'ENSD',
|
'eta_noise_seed_delta': 'ENSD',
|
||||||
'initial_noise_multiplier': 'Noise multiplier',
|
'initial_noise_multiplier': 'Noise multiplier',
|
||||||
|
'sd_vae': 'VAE',
|
||||||
}
|
}
|
||||||
settings_paste_fields = [
|
settings_paste_fields = [
|
||||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||||
|
@ -281,6 +282,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||||
|
|
||||||
restore_old_hires_fix_params(res)
|
restore_old_hires_fix_params(res)
|
||||||
|
|
||||||
|
if "VAE" in res:
|
||||||
|
vae_name = res["VAE"]
|
||||||
|
vae_hash = res.get("VAE hash", None)
|
||||||
|
res["VAE"] = sd_vae.find_vae_key(vae_name, vae_hash)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -449,6 +449,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||||
|
"VAE": "None" if shared.loaded_vae_file is None else os.path.split(shared.loaded_vae_file)[1].removesuffix(".pt"),
|
||||||
|
"VAE hash": None if shared.loaded_vae_file is None else sd_models.model_hash(shared.loaded_vae_file),
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
|
|
@ -2,10 +2,11 @@ import torch
|
||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from modules import shared, devices, script_callbacks
|
from modules import shared, devices, script_callbacks, sd_models
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
|
@ -24,11 +25,11 @@ default_vae_list = ["auto", "None"]
|
||||||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
||||||
vae_dict = dict(default_vae_dict)
|
vae_dict = dict(default_vae_dict)
|
||||||
vae_list = list(default_vae_list)
|
vae_list = list(default_vae_list)
|
||||||
|
vae_hash_to_filename = defaultdict(list)
|
||||||
first_load = True
|
first_load = True
|
||||||
|
|
||||||
|
|
||||||
base_vae = None
|
base_vae = None
|
||||||
loaded_vae_file = None
|
|
||||||
checkpoint_info = None
|
checkpoint_info = None
|
||||||
|
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
@ -42,7 +43,7 @@ def get_base_vae(model):
|
||||||
def store_base_vae(model):
|
def store_base_vae(model):
|
||||||
global base_vae, checkpoint_info
|
global base_vae, checkpoint_info
|
||||||
if checkpoint_info != model.sd_checkpoint_info:
|
if checkpoint_info != model.sd_checkpoint_info:
|
||||||
assert not loaded_vae_file, "Trying to store non-base VAE!"
|
assert not shared.loaded_vae_file, "Trying to store non-base VAE!"
|
||||||
base_vae = deepcopy(model.first_stage_model.state_dict())
|
base_vae = deepcopy(model.first_stage_model.state_dict())
|
||||||
checkpoint_info = model.sd_checkpoint_info
|
checkpoint_info = model.sd_checkpoint_info
|
||||||
|
|
||||||
|
@ -54,11 +55,11 @@ def delete_base_vae():
|
||||||
|
|
||||||
|
|
||||||
def restore_base_vae(model):
|
def restore_base_vae(model):
|
||||||
global loaded_vae_file
|
|
||||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||||
print("Restoring base VAE")
|
print("Restoring base VAE")
|
||||||
_load_vae_dict(model, base_vae)
|
_load_vae_dict(model, base_vae)
|
||||||
loaded_vae_file = None
|
shared.loaded_vae_file = None
|
||||||
|
shared.opts.sd_vae = "None"
|
||||||
delete_base_vae()
|
delete_base_vae()
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,7 +68,7 @@ def get_filename(filepath):
|
||||||
|
|
||||||
|
|
||||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||||
global vae_dict, vae_list
|
global vae_dict, vae_list, vae_hash_to_filename
|
||||||
res = {}
|
res = {}
|
||||||
candidates = [
|
candidates = [
|
||||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
||||||
|
@ -77,9 +78,11 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||||
]
|
]
|
||||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
||||||
candidates.append(shared.cmd_opts.vae_path)
|
candidates.append(shared.cmd_opts.vae_path)
|
||||||
|
vae_hash_to_filename.clear()
|
||||||
for filepath in candidates:
|
for filepath in candidates:
|
||||||
name = get_filename(filepath)
|
name = get_filename(filepath)
|
||||||
res[name] = filepath
|
res[name] = filepath
|
||||||
|
vae_hash_to_filename[sd_models.model_hash(filepath)].append(name)
|
||||||
vae_list.clear()
|
vae_list.clear()
|
||||||
vae_list.extend(default_vae_list)
|
vae_list.extend(default_vae_list)
|
||||||
vae_list.extend(list(res.keys()))
|
vae_list.extend(list(res.keys()))
|
||||||
|
@ -148,7 +151,7 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
||||||
|
|
||||||
|
|
||||||
def load_vae(model, vae_file=None):
|
def load_vae(model, vae_file=None):
|
||||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
global first_load, vae_dict, vae_list
|
||||||
# save_settings = False
|
# save_settings = False
|
||||||
|
|
||||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||||
|
@ -182,11 +185,10 @@ def load_vae(model, vae_file=None):
|
||||||
if vae_opt not in vae_dict:
|
if vae_opt not in vae_dict:
|
||||||
vae_dict[vae_opt] = vae_file
|
vae_dict[vae_opt] = vae_file
|
||||||
vae_list.append(vae_opt)
|
vae_list.append(vae_opt)
|
||||||
elif loaded_vae_file:
|
shared.loaded_vae_file = vae_file
|
||||||
|
elif shared.loaded_vae_file:
|
||||||
restore_base_vae(model)
|
restore_base_vae(model)
|
||||||
|
|
||||||
loaded_vae_file = vae_file
|
|
||||||
|
|
||||||
first_load = False
|
first_load = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -196,8 +198,7 @@ def _load_vae_dict(model, vae_dict_1):
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
|
||||||
def clear_loaded_vae():
|
def clear_loaded_vae():
|
||||||
global loaded_vae_file
|
shared.loaded_vae_file = None
|
||||||
loaded_vae_file = None
|
|
||||||
|
|
||||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import lowvram, devices, sd_hijack
|
||||||
|
@ -209,7 +210,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||||
|
|
||||||
if loaded_vae_file == vae_file:
|
if shared.loaded_vae_file == vae_file:
|
||||||
return
|
return
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
@ -229,3 +230,22 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||||
|
|
||||||
print("VAE Weights loaded.")
|
print("VAE Weights loaded.")
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_vae(vae_file: str):
|
||||||
|
return vae_file in vae_dict
|
||||||
|
|
||||||
|
|
||||||
|
def find_vae_key(vae_name, vae_hash=None):
|
||||||
|
"""Determines the config parameter name to use for the VAE based on the parameters in the infotext.
|
||||||
|
If vae_hash is provided, this function will return the name of any local VAE file that matches the hash.
|
||||||
|
If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name
|
||||||
|
"""
|
||||||
|
if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)):
|
||||||
|
if vae_name in matched or vae_name.lower() in matched:
|
||||||
|
return vae_name
|
||||||
|
return matched[0]
|
||||||
|
else:
|
||||||
|
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_list]:
|
||||||
|
return vae_name
|
||||||
|
return None
|
||||||
|
|
|
@ -142,6 +142,8 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||||
hypernetworks = {}
|
hypernetworks = {}
|
||||||
loaded_hypernetwork = None
|
loaded_hypernetwork = None
|
||||||
|
|
||||||
|
loaded_vae_file = None
|
||||||
|
|
||||||
|
|
||||||
def reload_hypernetworks():
|
def reload_hypernetworks():
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
|
@ -40,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
import modules.textual_inversion.ui
|
import modules.textual_inversion.ui
|
||||||
import modules.hypernetworks.ui
|
import modules.hypernetworks.ui
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
from modules.sd_vae import is_valid_vae
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
|
@ -495,6 +496,9 @@ def apply_setting(key, value):
|
||||||
value = ckpt_info.title
|
value = ckpt_info.title
|
||||||
else:
|
else:
|
||||||
return gr.update()
|
return gr.update()
|
||||||
|
if key == 'sd_vae' and not is_valid_vae(value):
|
||||||
|
# ignore invalid vaes
|
||||||
|
return gr.update()
|
||||||
|
|
||||||
comp_args = opts.data_labels[key].component_args
|
comp_args = opts.data_labels[key].component_args
|
||||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user