From 4f5f78658eb29763b8777cbcef6a01e6f8be62b2 Mon Sep 17 00:00:00 2001 From: Connum Date: Tue, 27 Sep 2022 19:31:01 +0200 Subject: [PATCH] make UI restraints (currently sampling method only) more flexible and reusable across scripts --- javascript/ui.js | 50 +++++++++++++++++++++++++------------------ modules/scripts.py | 29 +++++++++++++++++++++++-- modules/ui.py | 6 +++--- scripts/img2imgalt.py | 6 ++++++ style.css | 4 ++++ 5 files changed, 69 insertions(+), 26 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 4d13b515..491eb8f9 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -185,38 +185,46 @@ onUiUpdate(function(){ }) /** - * force Euler method for the "img2img alternative test" script + * Implement script-dependent UI restraints, e.g. forcing a specific sampling method */ -let prev_sampling_method; -onUiTabChange(function() { - const currentTab = get_uiCurrentTab(); - if ( ! currentTab || currentTab?.textContent.trim() !== 'img2img' ) { +let prev_ui_states = {}; +function updateScriptRestraints() { + const currentTab = get_uiCurrentTab()?.textContent.trim(); + const restraintsField = Array.from(gradioApp().querySelectorAll(`#${currentTab}_script_restraints_json textarea`)) + .filter(el => uiElementIsVisible(el.closest('.gr-form')))?.[0]; + + if ( ! restraintsField ) { return; } - const altScriptName = 'img2img alternative test'; - const scriptSelect = gradioApp().querySelector('#component-223 select'); - const methodRadios = gradioApp().querySelectorAll('[name="radio-component-182"]'); - scriptSelect.addEventListener( 'change', function() { - if( scriptSelect.value === altScriptName) { - prev_sampling_method = gradioApp().querySelector('[name="radio-component-182"]:checked'); + if ( typeof prev_ui_states[currentTab] === 'undefined' ) { + prev_ui_states[currentTab] = {}; + } + + window.requestAnimationFrame(() => { + const restraints = JSON.parse(restraintsField.value); + // const scriptSelect = gradioApp().querySelector(`#${currentTab}_scripts select`); + const methodRadios = gradioApp().querySelectorAll(`[name="radio-${currentTab}_sampling"]`); + + if( restraints?.methods?.length ) { + prev_ui_states[currentTab].sampling_method = gradioApp().querySelector(`[name="radio-${currentTab}_sampling"]:checked`); methodRadios.forEach(radio => { - const isEuler = radio.value === 'Euler'; + const isAllowed = restraints.methods.includes(radio.value); const label = radio.closest('label'); - radio.disabled = !isEuler; - radio.checked = isEuler; - label.classList[isEuler ? 'remove' : 'add']('!cursor-not-allowed'); - label.title = !isEuler ? `${altScriptName} only works with the Euler method` : ''; + radio.disabled = !isAllowed; + radio.checked = isAllowed; + label.classList[isAllowed ? 'remove' : 'add']('!cursor-not-allowed','disabled'); + label.title = !isAllowed ? `The selected script does not work with this method` : ''; }); } else { - // reset to previous method + // reset to previously selected method methodRadios.forEach(radio => { const label = radio.closest('label'); radio.disabled = false; - radio.checked = radio === prev_sampling_method; - label.classList.remove('!cursor-not-allowed'); + radio.checked = radio === prev_ui_states[currentTab].sampling_method; + label.classList.remove('!cursor-not-allowed','disabled'); label.title = ''; }); } - }); -}) \ No newline at end of file + }) +} \ No newline at end of file diff --git a/modules/scripts.py b/modules/scripts.py index 202374e6..ce77aef1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,6 +1,7 @@ import os import sys import traceback +import json import modules.ui as ui import gradio as gr @@ -24,6 +25,14 @@ class Script: def ui(self, is_img2img): pass + # Put restraints on UI elements when this script is selected. + # Restricting the available sampling methods: + # { + # "methods": [ "Euler", "DDIM" ] + # } + def ui_restraints(self): + return {} + # Determines when the script should be shown in the dropdown menu via the # returned value. As an example: # is_img2img is True if the current tab is img2img, and False if it is txt2img. @@ -106,7 +115,9 @@ class ScriptRunner: titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] - dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index") + id_prefix = "img2img_" if is_img2img else "txt2img_" + + dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index", elem_id=id_prefix+"scripts") inputs = [dropdown] for script in self.scripts: @@ -125,16 +136,23 @@ class ScriptRunner: inputs += controls script.args_to = len(inputs) + script_restraints_json = gr.Textbox(value="{}", elem_id=id_prefix+"script_restraints_json", show_label=False, visible=False) + inputs += [script_restraints_json]; + def select_script(script_index): if 0 < script_index <= len(self.scripts): script = self.scripts[script_index-1] args_from = script.args_from args_to = script.args_to else: + script = None args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + return ( + [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs)-1)] + + [gr.Textbox.update(value=json.dumps(script.ui_restraints() if script is not None else {}), visible=False)] + ) dropdown.change( fn=select_script, @@ -142,6 +160,13 @@ class ScriptRunner: outputs=inputs ) + script_restraints_json.change( + _js="updateScriptRestraints", + fn=lambda: None, + inputs=[], + outputs=[] + ) + return inputs def run(self, p: StableDiffusionProcessing, *args): diff --git a/modules/ui.py b/modules/ui.py index 5476c32f..fe6f94e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -409,7 +409,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20, elem_id="txt2img_steps") sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") with gr.Group(): @@ -588,8 +588,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20, elem_id="img2img_steps") + sampler_index = gr.Radio(label='Sampling method', elem_id="img2img_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") with gr.Group(): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 0ef137f7..07dabdce 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -129,6 +129,12 @@ class Script(scripts.Script): sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False) return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment] + def ui_restraints(self): + restraints = { + "methods": ["Euler"] + } + return restraints + def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment): p.batch_size = 1 p.batch_count = 1 diff --git a/style.css b/style.css index 4054e2df..438b8174 100644 --- a/style.css +++ b/style.css @@ -222,6 +222,10 @@ input[type="range"]{ margin: 0.5em 0 -0.3em 0; } +.gr-input-label.disabled { + opacity: 0.48; +} + #txt2img_sampling label{ padding-left: 0.6em; padding-right: 0.6em;