diff --git a/modules/sd_models.py b/modules/sd_models.py index dc81b0dc..9decc911 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -10,7 +10,7 @@ from ldm.util import instantiate_from_config from modules import shared -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} try: @@ -45,7 +45,8 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h) + model_name = title.rsplit(".",1)[0] # remove extension if present + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name) elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr) @@ -53,7 +54,8 @@ def list_models(): for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True): h = model_hash(filename) title = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h) + model_name = title.rsplit(".",1)[0] # remove extension if present + checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name) def model_hash(filename):