From 050a6a798cec90ae2f881c2ddd3f0221e69907dc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 23:26:48 +0300 Subject: [PATCH] support loading .yaml config with same name as model support EMA weights in processing (????) --- modules/processing.py | 2 +- modules/sd_models.py | 30 +++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 31220881..4fea6d56 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -347,7 +347,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - with torch.no_grad(): + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(all_prompts, all_seeds, all_subseeds) diff --git a/modules/sd_models.py b/modules/sd_models.py index a09866ce..cb3982b1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ from modules.paths import models_path model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} try: @@ -63,14 +63,20 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) + + basename, _ = os.path.splitext(filename) + config = basename + ".yaml" + if not os.path.exists(config): + config = shared.cmd_opts.config + + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) def get_closet_checkpoint_match(searchString): @@ -116,7 +122,10 @@ def select_checkpoint(): return checkpoint_info -def load_model_weights(model, checkpoint_file, sd_model_hash): +def load_model_weights(model, checkpoint_info): + checkpoint_file = checkpoint_info.filename + sd_model_hash = checkpoint_info.hash + print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location="cpu") @@ -148,15 +157,19 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file + model.sd_checkpoint_info = checkpoint_info def load_model(): from modules import lowvram, sd_hijack checkpoint_info = select_checkpoint() - sd_config = OmegaConf.load(shared.cmd_opts.config) + if checkpoint_info.config != shared.cmd_opts.config: + print(f"Loading config from: {shared.cmd_opts.config}") + + sd_config = OmegaConf.load(checkpoint_info.config) sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -178,6 +191,9 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return + if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + return load_model() + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -185,7 +201,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) + load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model)