add more stuff to ignore when creating model from config

prevent .vae.safetensors files from being listed as stable diffusion models
This commit is contained in:
AUTOMATIC 2023-01-10 16:51:04 +03:00
parent 0c3feb202c
commit ce3f639ec8
3 changed files with 56 additions and 9 deletions

View File

@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
""" """
A one-and done loader to try finding the desired models in specified directories. A one-and done loader to try finding the desired models in specified directories.
@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
full_path = file full_path = file
if os.path.isdir(full_path): if os.path.isdir(full_path):
continue continue
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
continue
if len(ext_filter) != 0: if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file) model_name, extension = os.path.splitext(file)
if extension not in ext_filter: if extension not in ext_filter:

View File

@ -1,15 +1,19 @@
import ldm.modules.encoders.modules import ldm.modules.encoders.modules
import open_clip import open_clip
import torch import torch
import transformers.utils.hub
class DisableInitialization: class DisableInitialization:
""" """
When an object of this class enters a `with` block, it starts preventing torch's layer initialization When an object of this class enters a `with` block, it starts:
functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves, - preventing torch's layer initialization functions from working
reverts everything to how it was. - changes CLIP and OpenCLIP to not download model weights
- changes CLIP to not make requests to check if there is a new version of a file you already have
Use like this: When it leaves the block, it reverts everything to how it was before.
Use it like this:
``` ```
with DisableInitialization(): with DisableInitialization():
do_things() do_things()
@ -26,19 +30,36 @@ class DisableInitialization:
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
    """Cache-first stand-in for ``transformers.utils.hub.get_from_cache``.

    Tries to resolve ``url`` from already-downloaded files first, so no
    request is made just to check whether a newer version of a file exists.
    Only if that attempt fails is the request repeated with network access
    allowed.

    ``self.transformers_utils_hub_get_from_cache`` is expected to hold the
    original, unpatched function.

    Args:
        url: location of the file being resolved.
        *args: forwarded unchanged to the original function.
        local_files_only: the caller's own offline preference. When true,
            a failed cache lookup is reported to the caller instead of
            being retried over the network.
        **kwargs: forwarded unchanged to the original function.

    Returns:
        Whatever the original function returns for ``url``.

    Raises:
        transformers.utils.hub.EntryNotFoundError: for the known-missing
            ``added_tokens.json`` file.
        Exception: whatever the original function raises when the lookup
            ultimately fails.
    """

    # this file is always 404, prevent making request
    if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
        raise transformers.utils.hub.EntryNotFoundError

    try:
        return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs)
    except Exception:
        # The caller explicitly asked for offline operation: falling back to
        # a network request here would silently override that request.
        if local_files_only:
            raise

        # Not available locally and the caller allows network access.
        return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs)
self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
self.init_no_grad_normal = torch.nn.init._no_grad_normal_ self.init_no_grad_normal = torch.nn.init._no_grad_normal_
self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
self.create_model_and_transforms = open_clip.create_model_and_transforms self.create_model_and_transforms = open_clip.create_model_and_transforms
self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache
torch.nn.init.kaiming_uniform_ = do_nothing torch.nn.init.kaiming_uniform_ = do_nothing
torch.nn.init._no_grad_normal_ = do_nothing torch.nn.init._no_grad_normal_ = do_nothing
torch.nn.init._no_grad_uniform_ = do_nothing
open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
torch.nn.init._no_grad_normal_ = self.init_no_grad_normal torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
open_clip.create_model_and_transforms = self.create_model_and_transforms open_clip.create_model_and_transforms = self.create_model_and_transforms
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache

View File

@ -2,6 +2,7 @@ import collections
import os.path import os.path
import sys import sys
import gc import gc
import time
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
@ -61,7 +62,7 @@ def find_checkpoint_config(info):
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
def modeltitle(path, shorthash): def modeltitle(path, shorthash):
abspath = os.path.abspath(path) abspath = os.path.abspath(path)
@ -288,6 +289,17 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper midas.api.load_model = load_model_wrapper
class Timer:
    """Stopwatch that reports the time elapsed between consecutive readings.

    Readings come from ``time.perf_counter()``, a monotonic clock, so the
    reported durations cannot go negative or jump when the system wall clock
    is adjusted. As a consequence ``start`` is only meaningful relative to
    other ``time.perf_counter()`` values; it is not an epoch timestamp.
    """

    def __init__(self):
        # Reference point for the next call to elapsed().
        self.start = time.perf_counter()

    def elapsed(self):
        """Return the seconds since the previous reading and restart the lap.

        The first call measures from construction; every later call measures
        from the call before it.
        """
        end = time.perf_counter()
        res = end - self.start
        # Restart from this reading so successive calls yield per-stage
        # durations rather than a running total.
        self.start = end
        return res
def load_model(checkpoint_info=None): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
@ -319,11 +331,17 @@ def load_model(checkpoint_info=None):
if shared.cmd_opts.no_half: if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False sd_config.model.params.unet_config.params.use_fp16 = False
timer = Timer()
with sd_disable_initialization.DisableInitialization(): with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
elapsed_create = timer.elapsed()
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
elapsed_load_weights = timer.elapsed()
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else: else:
@ -338,7 +356,9 @@ def load_model(checkpoint_info=None):
script_callbacks.model_loaded_callback(sd_model) script_callbacks.model_loaded_callback(sd_model)
print("Model loaded.") elapsed_the_rest = timer.elapsed()
print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
return sd_model return sd_model
@ -349,7 +369,7 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = shared.sd_model
if sd_model is None: # previous model load failed if sd_model is None: # previous model load failed
current_checkpoint_info = None current_checkpoint_info = None
else: else:
current_checkpoint_info = sd_model.sd_checkpoint_info current_checkpoint_info = sd_model.sd_checkpoint_info
@ -371,6 +391,8 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model)
timer = Timer()
try: try:
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
except Exception as e: except Exception as e:
@ -384,6 +406,8 @@ def reload_model_weights(sd_model=None, info=None):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device) sd_model.to(devices.device)
print("Weights loaded.") elapsed = timer.elapsed()
print(f"Weights loaded in {elapsed:.1f}s.")
return sd_model return sd_model