switched the token counter to use hidden buttons instead of api call

This commit is contained in:
Liam 2022-09-27 19:29:53 -04:00
parent 981fe9c4a3
commit e5707b66d6
4 changed files with 25 additions and 64 deletions

View File

@ -1,13 +0,0 @@
// helper functions
function debounce(func, wait_time) {
let timeout;
return function wrapped(...args) {
let call_function = () => {
clearTimeout(timeout);
func(...args)
}
clearTimeout(timeout);
timeout = setTimeout(call_function, wait_time);
};
}

View File

@ -182,51 +182,21 @@ onUiUpdate(function(){
}); });
json_elem.parentElement.style.display="none" json_elem.parentElement.style.display="none"
let debounce_time = 800
if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea")
txt2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "txt2img"), debounce_time))
}
if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea")
img2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "img2img"), debounce_time))
}
}) })
let wait_time = 800
let token_timeout;
function txt2img_token_counter(text) {
return update_token_counter("txt2img_token_button", text);
}
let txt2img_textarea, img2img_textarea = undefined; function img2img_token_counter(text) {
function submit_prompt_text(source, e) { return update_token_counter("img2img_token_button", text);
let prompt_text; }
if (source == "txt2img")
prompt_text = txt2img_textarea.value; function update_token_counter(button_id, text) {
else if (source == "img2img") if (token_timeout)
prompt_text = img2img_textarea.value; clearTimeout(token_timeout);
if (!prompt_text) token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
return; return [];
params = { }
method: "POST",
headers: {
"Accept": "application/json",
"Content-type": "application/json"
},
body: JSON.stringify({data:[prompt_text]})
}
fetch('http://127.0.0.1:7860/api/tokenize/', params)
.then((response) => response.json())
.then((data) => {
if (data?.data.length) {
let response_json = data.data[0]
if (elem = gradioApp().getElementById(source+"_token_counter")) {
if (response_json.token_count > response_json.max_length)
elem.classList.add("red");
else
elem.classList.remove("red");
elem.innerText = response_json.token_count + "/" + response_json.max_length;
}
}
})
.catch((error) => {
console.error('Error:', error);
});
}

View File

@ -273,8 +273,7 @@ class StableDiffusionModelHijack:
def tokenize(self, text): def tokenize(self, text):
max_length = self.clip.max_length - 2 max_length = self.clip.max_length - 2
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length} return remade_batch_tokens[0], token_count, max_length
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):

View File

@ -23,6 +23,7 @@ from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_samplers import samplers, samplers_for_img2img
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.helpers import debounce
import modules.ldsr_model import modules.ldsr_model
import modules.scripts import modules.scripts
import modules.gfpgan_model import modules.gfpgan_model
@ -330,6 +331,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component] outputs=[seed, dummy_component]
) )
def update_token_counter(text):
tokens, token_count, max_length = model_hijack.tokenize(text)
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
def create_toprow(is_img2img): def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img" id_part = "img2img" if is_img2img else "txt2img"
@ -339,15 +344,15 @@ def create_toprow(is_img2img):
with gr.Row(): with gr.Row():
with gr.Column(scale=80): with gr.Column(scale=80):
with gr.Row(): with gr.Row():
prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2) prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
prompt.change(fn=lambda *args: [], _js=f"{id_part}_token_counter", inputs=[prompt], outputs=[], preprocess=False)
with gr.Column(scale=1, elem_id="roll_col"): with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter") token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_output = gr.JSON(visible=False) hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
if is_img2img: # only define the api function ONCE hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output])
with gr.Column(scale=10, elem_id="style_pos_col"): with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)