bring back short hashes to sd checkpoint selection
This commit is contained in:
parent
d1ea518dea
commit
c1928cdd61
|
@ -41,14 +41,16 @@ class CheckpointInfo:
|
||||||
if name.startswith("\\") or name.startswith("/"):
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
name = name[1:]
|
name = name[1:]
|
||||||
|
|
||||||
self.title = name
|
self.name = name
|
||||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
self.hash = model_hash(filename)
|
self.hash = model_hash(filename)
|
||||||
|
|
||||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
|
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
|
||||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
|
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||||
|
|
||||||
|
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
|
@ -56,13 +58,15 @@ class CheckpointInfo:
|
||||||
checkpoint_alisases[id] = self
|
checkpoint_alisases[id] = self
|
||||||
|
|
||||||
def calculate_shorthash(self):
|
def calculate_shorthash(self):
|
||||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
|
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||||
self.shorthash = self.sha256[0:10]
|
self.shorthash = self.sha256[0:10]
|
||||||
|
|
||||||
if self.shorthash not in self.ids:
|
if self.shorthash not in self.ids:
|
||||||
self.ids += [self.shorthash, self.sha256]
|
self.ids += [self.shorthash, self.sha256]
|
||||||
self.register()
|
self.register()
|
||||||
|
|
||||||
|
self.title = f'{self.name} [{self.shorthash}]'
|
||||||
|
|
||||||
return self.shorthash
|
return self.shorthash
|
||||||
|
|
||||||
|
|
||||||
|
@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
||||||
|
title = checkpoint_info.title
|
||||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
|
if checkpoint_info.title != title:
|
||||||
|
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||||
|
|
||||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||||
|
|
||||||
|
|
|
@ -439,7 +439,7 @@ def apply_setting(key, value):
|
||||||
opts.data_labels[key].onchange()
|
opts.data_labels[key].onchange()
|
||||||
|
|
||||||
opts.save(shared.config_filename)
|
opts.save(shared.config_filename)
|
||||||
return value
|
return getattr(opts, key)
|
||||||
|
|
||||||
|
|
||||||
def update_generation_info(generation_info, html_info, img_index):
|
def update_generation_info(generation_info, html_info, img_index):
|
||||||
|
@ -597,6 +597,16 @@ def ordered_ui_categories():
|
||||||
yield category
|
yield category
|
||||||
|
|
||||||
|
|
||||||
|
def get_value_for_setting(key):
|
||||||
|
value = getattr(opts, key)
|
||||||
|
|
||||||
|
info = opts.data_labels[key]
|
||||||
|
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
||||||
|
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
||||||
|
|
||||||
|
return gr.update(value=value, **args)
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
@ -1600,7 +1610,7 @@ def create_ui():
|
||||||
|
|
||||||
opts.save(shared.config_filename)
|
opts.save(shared.config_filename)
|
||||||
|
|
||||||
return gr.update(value=value), opts.dumpjson()
|
return get_value_for_setting(key), opts.dumpjson()
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -1771,15 +1781,6 @@ def create_ui():
|
||||||
|
|
||||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
||||||
|
|
||||||
def get_value_for_setting(key):
|
|
||||||
value = getattr(opts, key)
|
|
||||||
|
|
||||||
info = opts.data_labels[key]
|
|
||||||
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
|
||||||
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
|
||||||
|
|
||||||
return gr.update(value=value, **args)
|
|
||||||
|
|
||||||
def get_settings_values():
|
def get_settings_values():
|
||||||
return [get_value_for_setting(key) for key in component_keys]
|
return [get_value_for_setting(key) for key in component_keys]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user