X/Y plot support for switching checkpoints.

This commit is contained in:
AUTOMATIC 2022-09-17 13:49:36 +03:00
parent 99585b3514
commit 304222ef94
3 changed files with 19 additions and 2 deletions

View File

@ -127,9 +127,9 @@ def load_model():
return sd_model return sd_model
def reload_model_weights(sd_model): def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices from modules import lowvram, devices
checkpoint_info = select_checkpoint() checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename: if sd_model.sd_model_checkpint == checkpoint_info.filename:
return return

View File

@ -66,6 +66,8 @@ titles = {
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both", "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Apply style": "Insert selected styles into prompt fields", "Apply style": "Insert selected styles into prompt fields",
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.", "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
} }
function gradioApp(){ function gradioApp(){

View File

@ -10,7 +10,9 @@ import gradio as gr
from modules import images from modules import images
from modules.processing import process_images, Processed from modules.processing import process_images, Processed
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
import modules.sd_models
import re import re
@ -41,6 +43,15 @@ def apply_sampler(p, x, xs):
p.sampler_index = sampler_index p.sampler_index = sampler_index
def apply_checkpoint(p, x, xs):
applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
assert len(applicable) > 0, f'Checkpoint {x} for found'
info = applicable[0]
modules.sd_models.reload_model_weights(shared.sd_model, info)
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
@ -74,6 +85,7 @@ axis_options = [
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
] ]
@ -215,4 +227,7 @@ class Script(scripts.Script):
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
# restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model)
return processed return processed