From e5707b66d6db2c019bfccf66f9ba53e3daaea40b Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 27 Sep 2022 19:29:53 -0400 Subject: [PATCH] switched the token counter to use hidden buttons instead of api call --- javascript/helpers.js | 13 ---------- javascript/ui.js | 60 +++++++++++-------------------------------- modules/sd_hijack.py | 3 +-- modules/ui.py | 13 +++++++--- 4 files changed, 25 insertions(+), 64 deletions(-) delete mode 100644 javascript/helpers.js diff --git a/javascript/helpers.js b/javascript/helpers.js deleted file mode 100644 index 1b26931f..00000000 --- a/javascript/helpers.js +++ /dev/null @@ -1,13 +0,0 @@ -// helper functions - -function debounce(func, wait_time) { - let timeout; - return function wrapped(...args) { - let call_function = () => { - clearTimeout(timeout); - func(...args) - } - clearTimeout(timeout); - timeout = setTimeout(call_function, wait_time); - }; -} \ No newline at end of file diff --git a/javascript/ui.js b/javascript/ui.js index fbe5a11d..6cfa5c08 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -182,51 +182,21 @@ onUiUpdate(function(){ }); json_elem.parentElement.style.display="none" - - let debounce_time = 800 - if (!txt2img_textarea) { - txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea") - txt2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "txt2img"), debounce_time)) - } - if (!img2img_textarea) { - img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea") - img2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "img2img"), debounce_time)) - } }) +let wait_time = 800 +let token_timeout; +function txt2img_token_counter(text) { + return update_token_counter("txt2img_token_button", text); +} -let txt2img_textarea, img2img_textarea = undefined; -function submit_prompt_text(source, e) { - let prompt_text; - if (source == "txt2img") - prompt_text = txt2img_textarea.value; - else if (source == "img2img") - prompt_text = img2img_textarea.value; - if (!prompt_text) - return; - params = { - method: "POST", - headers: { - "Accept": "application/json", - "Content-type": "application/json" - }, - body: JSON.stringify({data:[prompt_text]}) - } - fetch('http://127.0.0.1:7860/api/tokenize/', params) - .then((response) => response.json()) - .then((data) => { - if (data?.data.length) { - let response_json = data.data[0] - if (elem = gradioApp().getElementById(source+"_token_counter")) { - if (response_json.token_count > response_json.max_length) - elem.classList.add("red"); - else - elem.classList.remove("red"); - elem.innerText = response_json.token_count + "/" + response_json.max_length; - } - } - }) - .catch((error) => { - console.error('Error:', error); - }); -} \ No newline at end of file +function img2img_token_counter(text) { + return update_token_counter("img2img_token_button", text); +} + +function update_token_counter(button_id, text) { + if (token_timeout) + clearTimeout(token_timeout); + token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); + return []; +} diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4d799ac0..bfbd07f9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -273,8 +273,7 @@ class StableDiffusionModelHijack: def tokenize(self, text): max_length = self.clip.max_length - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length} - + return remade_batch_tokens[0], token_count, max_length class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): diff --git a/modules/ui.py b/modules/ui.py index 9a3d69c8..15bfd697 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,6 +23,7 @@ from modules.shared import opts, cmd_opts import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_hijack import model_hijack +from modules.helpers import debounce import modules.ldsr_model import modules.scripts import modules.gfpgan_model @@ -330,6 +331,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) +def update_token_counter(text): + tokens, token_count, max_length = model_hijack.tokenize(text) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" def create_toprow(is_img2img): id_part = "img2img" if is_img2img else "txt2img" @@ -339,15 +344,15 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2) + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) + prompt.change(fn=lambda *args: [], _js=f"{id_part}_token_counter", inputs=[prompt], outputs=[], preprocess=False) with gr.Column(scale=1, elem_id="roll_col"): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_output = gr.JSON(visible=False) - if is_img2img: # only define the api function ONCE - token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output]) + hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter]) with gr.Column(scale=10, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)