do not load aesthetic clip model until it's needed

add refresh button for aesthetic embeddings
add aesthetic params to images' infotext
This commit is contained in:
AUTOMATIC 2022-10-21 16:10:51 +03:00
parent 7d6b388d71
commit df57064093
8 changed files with 89 additions and 39 deletions

View File

@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1):
def create_ui(): def create_ui():
import modules.ui
with gr.Group(): with gr.Group():
with gr.Accordion("Open for Clip Aesthetic!", open=False): with gr.Accordion("Open for Clip Aesthetic!", open=False):
with gr.Row(): with gr.Row():
@ -55,6 +57,8 @@ def create_ui():
label="Aesthetic imgs embedding", label="Aesthetic imgs embedding",
value="None") value="None")
modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
with gr.Row(): with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
placeholder="This text is used to rotate the feature space of the imgs embs", placeholder="This text is used to rotate the feature space of the imgs embs",
@ -66,11 +70,21 @@ def create_ui():
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
# Module-level cache for the aesthetic CLIP model; stays None until aesthetic_clip() is first called.
aesthetic_clip_model = None


def aesthetic_clip():
    """Return the cached aesthetic CLIP model, loading it on demand.

    The model is loaded with ``CLIPModel.from_pretrained`` using the
    ``name_or_path`` of the current SD model's wrapped text transformer.
    A (re)load happens when nothing is cached yet or when the cached model's
    ``name_or_path`` differs from that value; a freshly loaded model is moved
    to the CPU before being returned.
    """
    global aesthetic_clip_model

    wanted_path = shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path
    cached = aesthetic_clip_model

    # Reuse the cached instance when it already matches the active model.
    if cached is not None and cached.name_or_path == wanted_path:
        return cached

    aesthetic_clip_model = CLIPModel.from_pretrained(wanted_path)
    aesthetic_clip_model.cpu()
    return aesthetic_clip_model
def generate_imgs_embd(name, folder, batch_size): def generate_imgs_embd(name, folder, batch_size):
# clipModel = CLIPModel.from_pretrained( model = aesthetic_clip().to(device)
# shared.sd_model.cond_stage_model.clipModel.name_or_path
# )
model = shared.clip_model.to(device)
processor = CLIPProcessor.from_pretrained(model.name_or_path) processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad(): with torch.no_grad():
@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size):
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
torch.save(embs, path) torch.save(embs, path)
model = model.cpu() model.cpu()
del processor del processor
del embs del embs
gc.collect() gc.collect()
@ -132,7 +146,7 @@ class AestheticCLIP:
self.image_embs = None self.image_embs = None
self.load_image_embs(None) self.load_image_embs(None)
def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="", aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15, aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False): aesthetic_text_negative=False):
@ -145,6 +159,18 @@ class AestheticCLIP:
self.aesthetic_steps = aesthetic_steps self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name) self.load_image_embs(image_embs_name)
if self.image_embs_name is not None:
p.extra_generation_params.update({
"Aesthetic LR": aesthetic_lr,
"Aesthetic weight": aesthetic_weight,
"Aesthetic steps": aesthetic_steps,
"Aesthetic embedding": self.image_embs_name,
"Aesthetic slerp": aesthetic_slerp,
"Aesthetic text": aesthetic_imgs_text,
"Aesthetic text negative": aesthetic_text_negative,
"Aesthetic slerp angle": aesthetic_slerp_angle,
})
def set_skip(self, skip): def set_skip(self, skip):
self.skip = skip self.skip = skip
@ -168,7 +194,7 @@ class AestheticCLIP:
tokens = torch.asarray(remade_batch_tokens).to(device) tokens = torch.asarray(remade_batch_tokens).to(device)
model = copy.deepcopy(shared.clip_model).to(device) model = copy.deepcopy(aesthetic_clip()).to(device)
model.requires_grad_(True) model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features( text_embs_2 = model.get_text_features(

View File

@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update()) type_of_gr_update = type(gr.update())
def quote(text):
    """Quote a parameter value when its text form contains a comma.

    If ``str(text)`` has no comma, the value is returned unchanged (the same
    object, not converted to a string). Otherwise the string form has its
    backslashes and double quotes escaped with a backslash and is wrapped in
    double quotes.
    """
    as_str = str(text)
    if ',' in as_str:
        # Escape backslashes first so the escapes added for quotes stay intact.
        escaped = as_str.replace('\\', '\\\\').replace('"', '\\"')
        return '"' + escaped + '"'
    return text
def parse_generation_parameters(x: str): def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI: """parses generation parameters string, the one you see in text field under the picture in UI:
``` ```
@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else: else:
try: try:
valtype = type(output.value) valtype = type(output.value)
if valtype == bool and v == "False":
val = False
else:
val = valtype(v) val = valtype(v)
res.append(gr.update(value=val)) res.append(gr.update(value=val))
except Exception: except Exception:
res.append(gr.update()) res.append(gr.update())

View File

@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
) )
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative)
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

View File

@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

View File

@ -234,9 +234,6 @@ def load_model(checkpoint_info=None):
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
sd_model.eval() sd_model.eval()
print(f"Model loaded.") print(f"Model loaded.")

View File

@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None, firstphase_height=firstphase_height if enable_hr else None,
) )
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
aesthetic_text_negative)
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

View File

@ -597,11 +597,7 @@ def apply_setting(key, value):
return value return value
def create_ui(wrap_gradio_gpu_call): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
import modules.img2img
import modules.txt2img
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh(): def refresh():
refresh_method() refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args args = refreshed_args() if callable(refreshed_args) else refreshed_args
@ -613,12 +609,18 @@ def create_ui(wrap_gradio_gpu_call):
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
refresh_button.click( refresh_button.click(
fn = refresh, fn=refresh,
inputs = [], inputs=[],
outputs = [refresh_component] outputs=[refresh_component]
) )
return refresh_button return refresh_button
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"), (firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"), (firstphase_height, "First pass size-2"),
(aesthetic_lr, "Aesthetic LR"),
(aesthetic_weight, "Aesthetic weight"),
(aesthetic_steps, "Aesthetic steps"),
(aesthetic_imgs, "Aesthetic embedding"),
(aesthetic_slerp, "Aesthetic slerp"),
(aesthetic_imgs_text, "Aesthetic text"),
(aesthetic_text_negative, "Aesthetic text negative"),
(aesthetic_slerp_angle, "Aesthetic slerp angle"),
] ]
txt2img_preview_params = [ txt2img_preview_params = [
@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"), (seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(aesthetic_lr_im, "Aesthetic LR"),
(aesthetic_weight_im, "Aesthetic weight"),
(aesthetic_steps_im, "Aesthetic steps"),
(aesthetic_imgs_im, "Aesthetic embedding"),
(aesthetic_slerp_im, "Aesthetic slerp"),
(aesthetic_imgs_text_im, "Aesthetic text"),
(aesthetic_text_negative_im, "Aesthetic text negative"),
(aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
] ]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])

View File

@ -477,7 +477,7 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ #refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em; min-width: 2.5em;
height: 2.4em; height: 2.4em;