find configs for models at runtime rather than when starting

This commit is contained in:
AUTOMATIC 2023-01-04 12:47:42 +03:00
parent 02d7abf514
commit 8d8a05a3bb
2 changed files with 22 additions and 14 deletions

View File

@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
def should_hijack_inpainting(checkpoint_info): def should_hijack_inpainting(checkpoint_info):
from modules import sd_models
ckpt_basename = os.path.basename(checkpoint_info.filename).lower() ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(checkpoint_info.config).lower() cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename

View File

@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {} checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict() checkpoints_loaded = collections.OrderedDict()
@ -48,6 +48,14 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
def find_checkpoint_config(info):
config = os.path.splitext(info.filename)[0] + ".yaml"
if os.path.exists(config):
return config
return shared.cmd_opts.config
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@ -73,7 +81,7 @@ def list_models():
if os.path.exists(cmd_ckpt): if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt) h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h) title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@ -81,12 +89,7 @@ def list_models():
h = model_hash(filename) h = model_hash(filename)
title, short_model_name = modeltitle(filename, h) title, short_model_name = modeltitle(filename, h)
basename, _ = os.path.splitext(filename) checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
config = basename + ".yaml"
if not os.path.exists(config):
config = shared.cmd_opts.config
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
def get_closet_checkpoint_match(searchString): def get_closet_checkpoint_match(searchString):
@ -282,9 +285,10 @@ def enable_midas_autodownload():
def load_model(checkpoint_info=None): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
checkpoint_config = find_checkpoint_config(checkpoint_info)
if checkpoint_info.config != shared.cmd_opts.config: if checkpoint_config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}") print(f"Loading config from: {checkpoint_config}")
if shared.sd_model: if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model) sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
sd_config = OmegaConf.load(checkpoint_info.config) sd_config = OmegaConf.load(checkpoint_config)
if should_hijack_inpainting(checkpoint_info): if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now... # Hardcoded config for now...
@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.finetune_keys = None sd_config.model.params.finetune_keys = None
# Create a "fake" config with a different name so that we know to unload it when switching models. # Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
if not hasattr(sd_config.model.params, "use_ema"): if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False sd_config.model.params.use_ema = False
@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
sd_model = shared.sd_model sd_model = shared.sd_model
current_checkpoint_info = sd_model.sd_checkpoint_info current_checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info) load_model(checkpoint_info)