From c091cf1b4acd2047644d3571bcbfd81c81b4c3af Mon Sep 17 00:00:00 2001 From: Nick Petalas Date: Thu, 22 Dec 2022 22:28:10 +0000 Subject: [PATCH 001/151] upgrading torch, torchvision, xformers (windows) to use u117 --- launch.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/launch.py b/launch.py index 715427fd..f51f23f7 100644 --- a/launch.py +++ b/launch.py @@ -179,7 +179,7 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -187,8 +187,6 @@ def prepare_environment(): clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") - xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') @@ -239,7 +237,7 @@ def prepare_environment(): if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): - run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") + run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers") else: print("Installation of xformers is not supported in this version of Python.") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") From 3262e825cc542ff634e6ba2e3a162eafdc6c1bba Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 21 Jan 2023 17:42:04 +0900 Subject: [PATCH 002/151] add --xformers-flash-attention option & impl --- modules/sd_hijack_optimizations.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 4fa54329..9967359b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -290,7 +290,19 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + # print('xformers_attention_forward', q.shape, k.shape, v.shape) + # Flash Attention is not 
availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -365,7 +377,17 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v) + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)): + # print('xformers_attnblock_forward', q.shape, k.shape, v.shape) + # Flash Attention is not availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + out = xformers.ops.memory_efficient_attention(q, k, v, op=op) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 72fb1934..23328adf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -57,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") +parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. 
By default, it's on for torch cuda.") parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") From bf457b30fbfedb4b6eb2a198cbaa9f2ba071d31f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 21 Jan 2023 16:21:33 -0500 Subject: [PATCH 003/151] compact checkbox and fix copy image btn overflow also fixes type for #tab_extensions in style.css --- modules/ui.py | 2 +- style.css | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index af6dfb21..12fc9e6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -919,7 +919,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): + with FormRow(elem_id="img2img_checkboxes", variant="compact"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") diff --git a/style.css b/style.css index 507acec1..b215405d 100644 --- a/style.css +++ b/style.css @@ -589,7 +589,7 @@ canvas[key="mask"] { /* Extensions */ -#tab_extensions table``{ +#tab_extensions table{ border-collapse: collapse; } @@ -718,6 +718,10 @@ footer { margin-left: -0.8em; } +#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{ + margin-left: 0em; +} + .inactive{ opacity: 0.5; } From 5560150fdaf5d974a122f0b226d6abe24dea12c0 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 21 Jan 2023 16:58:45 -0500 Subject: [PATCH 004/151] aligns the axis buttons in x/y plot --- scripts/xy_grid.py | 2 +- style.css | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..0caece09 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -307,7 +307,7 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) - with gr.Row(variant="compact"): + with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) diff --git a/style.css b/style.css index b215405d..bf8260d7 100644 --- a/style.css +++ b/style.css @@ -722,6 +722,10 @@ footer { margin-left: 0em; } +#axis_options { + margin-left: 0em; +} + .inactive{ opacity: 0.5; } From 43ac9ff205910e8207dfd45a842577344d399a92 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:40 +0300 Subject: [PATCH 005/151] Split history extras.py to postprocessing.py --- modules/{extras.py => postprocessing.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{extras.py => postprocessing.py} (100%) diff --git a/modules/extras.py b/modules/postprocessing.py similarity index 100% rename from modules/extras.py rename to modules/postprocessing.py From b238b14ee459486c4734cc2899b83f547813a467 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> 
Date: Sun, 22 Jan 2023 15:26:40 +0300 Subject: [PATCH 006/151] Split history extras.py to postprocessing.py --- modules/{extras.py => temp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{extras.py => temp} (100%) diff --git a/modules/extras.py b/modules/temp similarity index 100% rename from modules/extras.py rename to modules/temp From c56b36712289020a98f0c77794b9045a251ecd55 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:41 +0300 Subject: [PATCH 007/151] Split history extras.py to postprocessing.py --- modules/{temp => extras.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{temp => extras.py} (100%) diff --git a/modules/temp b/modules/extras.py similarity index 100% rename from modules/temp rename to modules/extras.py From 68303c96e5ab31576a8238a24bf5b6191cf16ed1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:38:39 +0300 Subject: [PATCH 008/151] split oversize extras.py to postprocessing.py --- modules/extras.py | 217 +------------------------------- modules/postprocessing.py | 257 +------------------------------------- modules/ui.py | 10 +- modules/ui_components.py | 7 ++ webui.py | 1 - 5 files changed, 18 insertions(+), 474 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 385430dc..f04ddfc2 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,231 +1,16 @@ -from __future__ import annotations -import math import os import re -import sys -import traceback import shutil -import numpy as np -from PIL import Image import torch import tqdm -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model +from modules import shared, images, sd_models, sd_vae from modules.ui import plaintext_to_html -import modules.codeformer_model import gradio as gr import safetensors.torch -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, 
'--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] 
- facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - def run_pnginfo(image): if image is None: diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 385430dc..cb85720b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,28 +1,18 @@ from __future__ import annotations -import math import os -import re -import sys -import traceback -import shutil import numpy as np from PIL import Image -import torch -import tqdm - from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules import shared, images, devices, ui_components from modules.shared import opts import modules.gfpgan_model -from modules.ui import plaintext_to_html import modules.codeformer_model -import gradio as gr -import safetensors.torch + class LruCache(OrderedDict): @dataclass(frozen=True) @@ -55,7 +45,7 @@ class LruCache(OrderedDict): cached_images: LruCache = LruCache(max_size=5) -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def 
run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): devices.torch_gc() shared.state.begin() @@ -221,246 +211,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ devices.torch_gc() - return outputs, plaintext_to_html(info), '' + return outputs, ui_components.plaintext_to_html(info), '' + def clear_cache(): cached_images.clear() - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
<div>
-<p><b>{plaintext_to_html(str(key))}</b></p>
-<p>{plaintext_to_html(str(text))}</p>
-</div>
-""".strip()+"\n"
-
-    if len(info) == 0:
-        message = "Nothing found in the image."
-        info = f"<div><p>{message}<p></div>
" - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - 
shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." 
+ checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/ui.py b/modules/ui.py index eb4b7e6b..4116e167 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -95,8 +95,8 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): - text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" - return text + return ui_components.plaintext_to_html(text) + def send_gradio_gallery_to_image(x): if len(x) == 0: @@ -1152,7 +1152,7 @@ def create_ui(): result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -1183,7 +1183,7 @@ def create_ui(): parameters_copypaste.add_paste_fields("extras", extras_image, None) extras_image.change( - fn=modules.extras.clear_cache, + fn=postprocessing.clear_cache, inputs=[], outputs=[] ) diff --git a/modules/ui_components.py b/modules/ui_components.py index 46324425..989cc87b 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,3 +1,5 @@ +import html + import gradio as gr @@ -47,3 +49,8 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + + +def plaintext_to_html(text): + text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" + return text diff --git a/webui.py b/webui.py index d235da74..7cf5885e 100644 --- a/webui.py +++ b/webui.py @@ -22,7 +22,6 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer -import modules.extras import modules.face_restoration import modules.gfpgan_model as gfpgan import modules.img2img From 985c0b8e9abdd67734d638badefb6ea806b1f28b Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sat, 21 Jan 2023 17:45:36 -0300 Subject: [PATCH 009/151] feat(extra-networks): add thumbs view style --- html/image-update.svg | 3 ++ javascript/extraNetworks.js | 2 ++ modules/ui_extra_networks.py | 21 +++++++----- style.css | 64 ++++++++++++++++++++++++++++++++++-- 4 files changed, 78 insertions(+), 12 deletions(-) create mode 100644 html/image-update.svg diff --git a/html/image-update.svg b/html/image-update.svg new file mode 100644 index 00000000..525e4fc5 --- /dev/null +++ b/html/image-update.svg @@ -0,0 +1,3 @@ + + + diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c5a9adb3..1bda7c6e 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -6,11 +6,13 @@ function setupExtraNetworksForTab(tabname){ var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') var refresh = gradioApp().getElementById(tabname+'_extra_refresh') var close = gradioApp().getElementById(tabname+'_extra_close') + var view = gradioApp().getElementById(tabname+'_extra_view') search.classList.add('search') tabs.appendChild(search) tabs.appendChild(refresh) tabs.appendChild(close) + tabs.appendChild(view) search.addEventListener("input", function(evt){ searchTerm = search.value.toLowerCase() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index af2b8071..ce4801b5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,7 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname): + def create_html(self, tabname, view = 'cards'): items_html = '' for item in self.list_items(): @@ -36,7 +36,7 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" -
<div id='{tabname}_{self.name}_cards' class='extra-network-cards'>
+<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
 {items_html}
 </div>
""" @@ -75,6 +75,7 @@ class ExtraNetworksUi: self.button_save_preview = None self.preview_target_filename = None + self.view_dropdown = None self.tabname = None @@ -110,6 +111,7 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value='cards') ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -117,16 +119,17 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(): + def refresh(view='cards'): res = [] for pg in ui.stored_extra_pages: pg.refresh() - res.append(pg.create_html(ui.tabname)) + res.append(pg.create_html(ui.tabname, view)) return res - button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + ui.view_dropdown.change(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) + button_refresh.click(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) return ui @@ -139,7 +142,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename): + def save_preview(index, images, filename, view='cards'): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -161,11 +164,11 @@ def setup_ui(ui, gallery): image.save(filename) - return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + return [page.create_html(ui.tabname, view) for page in ui.stored_extra_pages] ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", - inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + _js="function(x, y, z, a){console.log(x, y, z, a); return [selected_gallery_index(), y, z, a]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename, ui.view_dropdown], outputs=[*ui.pages] ) diff --git a/style.css b/style.css index 507acec1..ca0a172b 100644 --- a/style.css +++ b/style.css @@ -784,21 +784,79 @@ footer { display: inline-block; max-width: 16em; margin: 0.3em; + align-self: center; } -.extra-network-cards .nocards{ +#txt2img_extra_view, #img2img_extra_view { + width: auto; +} + +.extra-network-cards .nocards, .extra-network-thumbs .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -.extra-network-cards .nocards h1{ +.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{ font-size: 1.5em; margin-bottom: 1em; } -.extra-network-cards .nocards li{ +.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{ margin-left: 0.5em; } +.extra-network-thumbs { + display: flex; + flex-flow: row wrap; + gap: 10px; +} + +.extra-network-thumbs .card { + height: 6em; + width: 6em; + cursor: pointer; + background-image: url('./file=html/card-no-preview.png'); + background-size: cover; + background-position: center center; + position: relative; +} + +.extra-network-thumbs .card:hover .additional a { + display: block; +} + 
+.extra-network-thumbs .actions .additional a { + background-image: url('./file=html/image-update.svg'); + background-repeat: no-repeat; + background-size: cover; + background-position: center center; + position: absolute; + top: 0; + left: 0; + width: 24px; + height: 24px; + display: none; + font-size: 0; + text-align: -9999; + background-color: #fff; +} + +.extra-network-thumbs .actions .name { + position: absolute; + bottom: 0; + font-size: 10px; + padding: 3px; + width: 100%; + overflow: hidden; + white-space: nowrap; + text-overflow: ellipsis; + background: rgba(0,0,0,.5); +} + +.extra-network-thumbs .card:hover .actions .name { + white-space: normal; + word-break: break-all; +} + .extra-network-cards .card{ display: inline-block; margin: 0.5em; From 66eef11ce7f3db108225668c573cb4a763a43fb3 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sat, 21 Jan 2023 18:27:57 -0300 Subject: [PATCH 010/151] feat(extra-networks): add default view setting --- modules/shared.py | 4 ++++ modules/ui_extra_networks.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index cd78e50a..e9548864 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,6 +430,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), })) +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), +})) + options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index ce4801b5..179ba47a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,7 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname, view = 'cards'): + def create_html(self, tabname, view=shared.opts.extra_networks_default_view): items_html = '' for item in self.list_items(): @@ -111,7 +111,7 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value='cards') + ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value=lambda: shared.opts.extra_networks_default_view) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -119,7 +119,7 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(view='cards'): + def refresh(view): res = [] for pg in ui.stored_extra_pages: @@ -142,7 +142,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, 
filename, view='cards'): + def save_preview(index, images, filename, view): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] From 8a3f85c4cc1910f59e04b5c8355a30c4c42431e5 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 22 Jan 2023 17:08:08 -0500 Subject: [PATCH 011/151] adds hires steps to x/y plot and fixes total_steps calculation --- scripts/xy_grid.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..5990b78d 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -184,6 +184,7 @@ axis_options = [ AxisOption("Var. seed", int, apply_field("subseed")), AxisOption("Var. strength", float, apply_field("subseed_strength")), AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), @@ -427,10 +428,21 @@ class Script(scripts.Script): total_steps = p.steps * len(xs) * len(ys) if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - total_steps *= 2 + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + else: + total_steps *= 2 - print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") - shared.total_tqdm.updateTotal(total_steps * p.n_iter) + total_steps *= p.n_iter + + image_cell_count = p.n_iter * p.batch_size + cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" + print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. 
(Total steps to process: {total_steps})") + shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] From f80ff3c1e444926879c284be9384a26ca38d4955 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sun, 22 Jan 2023 22:01:24 -0300 Subject: [PATCH 012/151] feat(extra-networks): remove view dropdown --- javascript/extraNetworks.js | 2 -- modules/ui_extra_networks.py | 20 +++++++++----------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 1bda7c6e..c5a9adb3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -6,13 +6,11 @@ function setupExtraNetworksForTab(tabname){ var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') var refresh = gradioApp().getElementById(tabname+'_extra_refresh') var close = gradioApp().getElementById(tabname+'_extra_close') - var view = gradioApp().getElementById(tabname+'_extra_view') search.classList.add('search') tabs.appendChild(search) tabs.appendChild(refresh) tabs.appendChild(close) - tabs.appendChild(view) search.addEventListener("input", function(evt){ searchTerm = search.value.toLowerCase() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 179ba47a..2ddac3d8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,8 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname, view=shared.opts.extra_networks_default_view): + def create_html(self, tabname): + view = shared.opts.extra_networks_default_view items_html = '' for item in self.list_items(): @@ -75,7 +76,6 @@ class ExtraNetworksUi: self.button_save_preview = None self.preview_target_filename = None - self.view_dropdown = None self.tabname = None @@ -111,7 +111,6 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value=lambda: shared.opts.extra_networks_default_view) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -119,17 +118,16 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(view): + def refresh(): res = [] for pg in ui.stored_extra_pages: pg.refresh() - res.append(pg.create_html(ui.tabname, view)) + res.append(pg.create_html(ui.tabname)) return res - ui.view_dropdown.change(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) - button_refresh.click(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui @@ -142,7 +140,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename, view): + def save_preview(index, images, filename): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -164,11 +162,11 @@ def setup_ui(ui, gallery): image.save(filename) - 
return [page.create_html(ui.tabname, view) for page in ui.stored_extra_pages] + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z, a){console.log(x, y, z, a); return [selected_gallery_index(), y, z, a]}", - inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename, ui.view_dropdown], + _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], outputs=[*ui.pages] ) From b5230197a69d36a79fdc4919c59a03e00e872dd3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 09:24:43 +0300 Subject: [PATCH 013/151] rework extras tab to use script system --- javascript/ui.js | 5 - modules/api/api.py | 13 +- modules/postprocessing.py | 226 ++++++----------------- modules/scripts.py | 28 ++- modules/scripts_postprocessing.py | 147 +++++++++++++++ modules/shared.py | 5 + modules/ui.py | 265 +-------------------------- modules/ui_common.py | 202 ++++++++++++++++++++ modules/ui_components.py | 6 - modules/ui_postprocessing.py | 57 ++++++ scripts/postprocessing_codeformer.py | 36 ++++ scripts/postprocessing_gfpgan.py | 33 ++++ scripts/postprocessing_upscale.py | 106 +++++++++++ 13 files changed, 670 insertions(+), 459 deletions(-) create mode 100644 modules/scripts_postprocessing.py create mode 100644 modules/ui_common.py create mode 100644 modules/ui_postprocessing.py create mode 100644 scripts/postprocessing_codeformer.py create mode 100644 scripts/postprocessing_gfpgan.py create mode 100644 scripts/postprocessing_upscale.py diff --git a/javascript/ui.js b/javascript/ui.js index 77256e15..ba72623c 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){ return res } -function get_extras_tab_index(){ - const [,,...args] = [...arguments] - return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args] -} - function get_img2img_tab_index() { let res = args_to_array(arguments) res.splice(-2) diff --git a/modules/api/api.py b/modules/api/api.py index f2e9e884..5d60fc0a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -45,10 +44,8 @@ def validate_sampler_name(name): def setUpscalers(req: dict): reqDict = vars(req) - reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) - reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) + reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict def decode_base64_to_image(encoding): @@ -244,7 +241,7 @@ class Api: reqDict['image'] = 
decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) @@ -260,7 +257,7 @@ class Api: reqDict.pop('imageList') with self.queue_lock: - result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index cb85720b..8514fea7 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,219 +1,103 @@ -from __future__ import annotations import os -import numpy as np from PIL import Image -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import shared, images, devices, ui_components +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste from modules.shared import opts -import modules.gfpgan_model -import modules.codeformer_model -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin() shared.state.job = 'extras' - imageArr = [] - # Also keep track of original file names - imageNameArr = [] + image_data = [] + image_names = [] outputs = [] if extras_mode == 1: - #convert file to pillow image for img in image_folder: image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) + image_data.append(image) + image_names.append(os.path.splitext(img.orig_name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + assert input_dir, 'input directory not selected' - if input_dir == '': - return outputs, "Please select an input directory.", '' image_list = shared.listfiles(input_dir) - for img in image_list: + for filename in image_list: try: - image = Image.open(img) + image = Image.open(filename) except 
Exception: continue - imageArr.append(image) - imageNameArr.append(img) + image_data.append(image) + image_names.append(filename) else: - imageArr.append(image) - imageNameArr.append(None) + assert image, 'image not selected' + + image_data.append(image) + image_names.append(None) if extras_mode == 2 and output_dir != '': outpath = output_dir else: outpath = opts.outdir_samples or opts.outdir_extras_samples - # Extra operation definitions + infotext = '' - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) + for image, name in zip(image_data, image_names): + shared.state.textinfo = name - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - 
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - existing_pnginfo = image.info or {} - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) + pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] + scripts.scripts_postproc.run(pp, args) + + if opts.use_original_name_batch and name is not None: + basename = os.path.splitext(os.path.basename(name))[0] else: basename = '' - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info + infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + + if opts.enable_pnginfo: + pp.image.info = existing_pnginfo + pp.image.info["postprocessing"] = infotext if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. 
- if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) + if extras_mode != 2 or show_extras_results: + outputs.append(pp.image) devices.torch_gc() - return outputs, ui_components.plaintext_to_html(info), '' + return outputs, ui_common.plaintext_to_html(infotext), '' -def clear_cache(): - cached_images.clear() +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + """old handler for API""" + args = scripts.scripts_postproc.create_args_for_run({ + "Upscale": { + "upscale_mode": resize_mode, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + }, + "GFPGAN": { + "gfpgan_visibility": gfpgan_visibility, + }, + "CodeFormer": { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + }, + }) + + return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) diff --git a/modules/scripts.py b/modules/scripts.py index 4ffc369b..03907a63 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ from collections import namedtuple import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks, extensions, script_loading +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() @@ -150,8 +150,10 @@ def basedir(): return current_basedir -scripts_data = [] ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) + +scripts_data = [] +postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) @@ -190,23 +192,31 @@ def list_files_with_name(filename): def load_scripts(): global current_basedir scripts_data.clear() + postprocessing_scripts_data.clear() script_callbacks.clear_callbacks() scripts_list = list_scripts("scripts", ".py") syspath = sys.path + def register_scripts_from_module(module): + for key, script_class in module.__dict__.items(): + if type(script_class) != type: + continue + + if issubclass(script_class, Script): + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): + 
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + for scriptfile in sorted(scripts_list): try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path current_basedir = scriptfile.basedir - module = script_loading.load_module(scriptfile.path) - - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + script_module = script_loading.load_module(scriptfile.path) + register_scripts_from_module(script_module) except Exception: print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) @@ -413,6 +423,7 @@ class ScriptRunner: scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() +scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() scripts_current: ScriptRunner = None @@ -423,12 +434,13 @@ def reload_script_body_only(): def reload_scripts(): - global scripts_txt2img, scripts_img2img + global scripts_txt2img, scripts_img2img, scripts_postproc load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() def IOComponent_init(self, *args, **kwargs): diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py new file mode 100644 index 00000000..25de02d0 --- /dev/null +++ b/modules/scripts_postprocessing.py @@ -0,0 +1,147 @@ +import os +import gradio as gr + +from modules import errors, shared + + +class PostprocessedImage: + def __init__(self, image): + self.image = image + self.info = {} + + +class ScriptPostprocessing: + filename = None + controls = None + args_from = None + args_to = None + + order = 1000 + """scripts will be ordred by this value in postprocessing UI""" + + name = None + """this function should return the title of the script.""" + + group = None + """A gr.Group component that has all script's UI inside it""" + + def ui(self): + """ + This function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be a dictionary that maps parameter names to components used in processing. + Values of those components will be passed to process() function. + """ + + pass + + def process(self, pp: PostprocessedImage, **args): + """ + This function is called to postprocess the image. 
+ args contains a dictionary with all values returned by components from ui() + """ + + pass + + def image_changed(self): + pass + + +def wrap_call(func, filename, funcname, *args, default=None, **kwargs): + try: + res = func(*args, **kwargs) + return res + except Exception as e: + errors.display(e, f"calling {filename}/{funcname}") + + return default + + +class ScriptPostprocessingRunner: + def __init__(self): + self.scripts = None + self.ui_created = False + + def initialize_scripts(self, scripts_data): + self.scripts = [] + + for script_class, path, basedir, script_module in scripts_data: + script: ScriptPostprocessing = script_class() + script.filename = path + + self.scripts.append(script) + + def create_script_ui(self, script, inputs): + script.args_from = len(inputs) + script.args_to = len(inputs) + + script.controls = wrap_call(script.ui, script.filename, "ui") + + for control in script.controls.values(): + control.custom_script_source = os.path.basename(script.filename) + + inputs += list(script.controls.values()) + script.args_to = len(inputs) + + def scripts_in_preferred_order(self): + if self.scripts is None: + import modules.scripts + self.initialize_scripts(modules.scripts.postprocessing_scripts_data) + + scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + + def script_score(name): + name = name.lower() + for i, possible_match in enumerate(scripts_order): + if possible_match in name: + return i + + return len(self.scripts) + + script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)} + + return sorted(self.scripts, key=lambda x: script_scores[x.name]) + + def setup_ui(self): + inputs = [] + + for script in self.scripts_in_preferred_order(): + with gr.Box() as group: + self.create_script_ui(script, inputs) + + script.group = group + + self.ui_created = True + return inputs + + def run(self, pp: PostprocessedImage, args): + for script in self.scripts_in_preferred_order(): + shared.state.job = script.name + + script_args = args[script.args_from:script.args_to] + + process_args = {} + for (name, component), value in zip(script.controls.items(), script_args): + process_args[name] = value + + script.process(pp, **process_args) + + def create_args_for_run(self, scripts_args): + if not self.ui_created: + with gr.Blocks(analytics_enabled=False): + self.setup_ui() + + scripts = self.scripts_in_preferred_order() + args = [None] * max([x.args_to for x in scripts]) + + for script in scripts: + script_args_dict = scripts_args.get(script.name, None) + if script_args_dict is not None: + + for i, name in enumerate(script.controls): + args[script.args_from + i] = script_args_dict.get(name, None) + + return args + + def image_changed(self): + for script in self.scripts_in_preferred_order(): + script.image_changed() diff --git a/modules/shared.py b/modules/shared.py index cd78e50a..cb73bf31 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -474,6 +474,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), })) +options_templates.update(options_section(('postprocessing', "Postprocessing"), { + 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, 
"maximum": 10, "step": 1}), +})) + options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable those extensions"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), diff --git a/modules/ui.py b/modules/ui.py index 4116e167..8cb8e613 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -86,7 +86,6 @@ css_hide_progressbar = """ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️ paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 @@ -95,7 +94,7 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): - return ui_components.plaintext_to_html(text) + return ui_common.plaintext_to_html(text) def send_gradio_gallery_to_image(x): @@ -103,70 +102,6 @@ def send_gradio_gallery_to_image(x): return None return image_from_url_text(x[0]) -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - 
from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -444,19 +379,6 @@ def apply_setting(key, value): opts.save(shared.config_filename) return getattr(opts, key) - -def update_generation_info(generation_info, html_info, img_index): - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info, gr.update() - return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info, gr.update() - - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -477,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel', elem_id=f"{tabname}_results"): - with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, 
html_info], - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ], - show_progress=False, - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return ui_common.create_output_panel(tabname, outdir) def create_sampler_and_steps_selection(choices, tabname): @@ -1106,86 +928,7 @@ def create_ui(): modules.scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='compact'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - 
gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=postprocessing.clear_cache, - inputs=[], outputs=[] - ) + ui_postprocessing.create_ui() with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): diff --git a/modules/ui_common.py b/modules/ui_common.py new file mode 100644 index 00000000..8ce75b8c --- /dev/null +++ b/modules/ui_common.py @@ -0,0 +1,202 @@ +import json +import html +import os +import platform +import sys + +import gradio as gr +import scipy as sp + +from modules import call_queue, shared +from modules.generation_parameters_copypaste import image_from_url_text +import modules.images + +folder_symbol = '\U0001f4c2' # 📂 + + +def update_generation_info(generation_info, html_info, img_index): + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info, gr.update() + + +def plaintext_to_html(text): + text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" + return text + + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = shared.opts.outdir_save + save_to_dirs = shared.opts.use_save_to_dirs_for_ui + extension: str = shared.opts.samples_format + start_index = 0 + + if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(shared.opts.outdir_save, exist_ok=True) + + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def create_output_panel(tabname, outdir): + from modules import shared + import modules.generation_parameters_copypaste as parameters_copypaste + + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. 
+Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(shared.opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], + ) + + save.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ], + show_progress=False, + ) + + save_zip.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log diff --git a/modules/ui_components.py b/modules/ui_components.py index 989cc87b..9aec3097 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,5 +1,3 @@ -import html - import gradio as gr @@ -50,7 +48,3 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" - -def plaintext_to_html(text): - text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" - return text diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py new file mode 100644 index 00000000..b418d955 --- /dev/null +++ b/modules/ui_postprocessing.py @@ -0,0 +1,57 @@ +import gradio as gr +from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue +import modules.generation_parameters_copypaste as parameters_copypaste + + +def create_ui(): + tab_index = gr.State(value=0) + + with gr.Row().style(equal_height=False, variant='compact'): + with gr.Column(variant='compact'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single: + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + script_inputs = scripts.scripts_postproc.setup_ui() + + with gr.Column(): + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + + tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) + tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) + tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) + + submit.click( + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + inputs=[ + tab_index, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + *script_inputs + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=scripts.scripts_postproc.image_changed, + inputs=[], outputs=[] + ) diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py new file mode 100644 index 00000000..a7d80d40 --- /dev/null +++ b/scripts/postprocessing_codeformer.py @@ -0,0 +1,36 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, codeformer_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): + name = "CodeFormer" + order = 3000 + + def ui(self): + with FormRow(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") + + return { + "codeformer_visibility": codeformer_visibility, 
+ "codeformer_weight": codeformer_weight, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0: + return + + restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(pp.image, res, codeformer_visibility) + + pp.image = res + pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) + pp.info["CodeFormer weight"] = round(codeformer_weight, 3) diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py new file mode 100644 index 00000000..d854f3f7 --- /dev/null +++ b/scripts/postprocessing_gfpgan.py @@ -0,0 +1,33 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, gfpgan_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): + name = "GFPGAN" + order = 2000 + + def ui(self): + with FormRow(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + + return { + "gfpgan_visibility": gfpgan_visibility, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): + if gfpgan_visibility == 0: + return + + restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(pp.image, res, gfpgan_visibility) + + pp.image = res + pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py new file mode 100644 index 00000000..095d29b2 --- /dev/null +++ b/scripts/postprocessing_upscale.py @@ -0,0 +1,106 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, shared +import gradio as gr + +from modules.ui_components import FormRow + + +upscale_cache = {} + + +class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): + name = "Upscale" + order = 1000 + + def ui(self): + selected_tab = gr.State(value=0) + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with FormRow(): + extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + + tab_scale_by.select(fn=lambda: 0, inputs=[], 
outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + + return { + "upscale_mode": selected_tab, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + } + + def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): + if upscale_mode == 1: + upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height) + info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" + else: + info["Postprocess upscale by"] = upscale_by + + cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + cached_image = upscale_cache.pop(cache_key, None) + + if cached_image is not None: + image = cached_image + else: + image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) + + upscale_cache[cache_key] = image + if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: + upscale_cache.pop(next(iter(upscale_cache), None), None) + + if upscale_mode == 1 and upscale_crop: + cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) + cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) + image = cropped + info["Postprocess crop to"] = f"{image.width}x{image.height}" + + return image + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): + if upscaler_1_name == "None": + upscaler_1_name = None + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) + assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' + + if not upscaler1: + return + + if upscaler_2_name == "None": + upscaler_2_name = None + + upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) + assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}' + + upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + pp.info[f"Postprocess upscaler"] = upscaler1.name + + if upscaler2 and upscaler_2_visibility > 0: + second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) + + pp.info[f"Postprocess upscaler 2"] = upscaler2.name + + pp.image = upscaled_image + + def image_changed(self): + upscale_cache.clear() From 669dbd9725b3a285503e093a75c0dfa332073d8a Mon Sep 17 00:00:00 2001 From: Shondoit Date: Mon, 23 Jan 2023 09:54:42 +0100 Subject: [PATCH 014/151] Fix dark mode Fixes #7048 Co-Authored-By: J.J. 
Tolton --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index eb4b7e6b..43bcb7e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1942,11 +1942,11 @@ def reload_javascript(): if cmd_opts.theme is not None: inline += f"set_theme('{cmd_opts.theme}');" - head += f'\n' - for script in modules.scripts.list_scripts("javascript", ".js"): head += f'\n' + head += f'\n' + def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace(b'', f'{head}'.encode("utf8")) From fabdae089e476c66eba3b0562e4e1881891804b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 14:42:49 +0300 Subject: [PATCH 015/151] add missing import to previous commit --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui.py b/modules/ui.py index 8cb8e613..6b5dfd61 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,6 +41,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +import modules.extras warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) From 41265a026de699cc223ca5b76c69b4e8e74aa7c1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 14:50:20 +0300 Subject: [PATCH 016/151] third time's the charm --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/extras.py b/modules/extras.py index f04ddfc2..36123aa5 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import torch import tqdm from modules import shared, images, sd_models, sd_vae -from modules.ui import plaintext_to_html +from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch From 194cbd065e4644e986889b78a5a949e075b610e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 15:50:32 +0300 Subject: [PATCH 017/151] fix open directory button failing --- modules/ui.py | 1 - modules/ui_common.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 94d4a80a..85ae62c7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,7 +5,6 @@ import mimetypes import os import platform import random -import subprocess as sp import sys import tempfile import time diff --git a/modules/ui_common.py b/modules/ui_common.py index 8ce75b8c..9405ac1f 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -5,7 +5,7 @@ import platform import sys import gradio as gr -import scipy as sp +import subprocess as sp from modules import call_queue, shared from modules.generation_parameters_copypaste import image_from_url_text From 59146621e256269b85feb536edeb745da20daf68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 16:40:20 +0300 Subject: [PATCH 018/151] better support for xformers flash attention on older versions of torch --- modules/errors.py | 12 +++++++++ modules/sd_hijack_optimizations.py | 42 +++++++++++++----------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index a10e8708..f6b80dbb 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable """) +already_displayed = {} + + 
+def display_once(e: Exception, task): + if task in already_displayed: + return + + display(e, task) + + already_displayed[task] = 1 + + def run(code, task): try: code() diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 9967359b..74452709 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared +from modules import shared, errors from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ ) +def get_xformers_flash_attention_op(q, k, v): + if not shared.cmd_opts.xformers_flash_attention: + return None + + try: + flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = flash_attention_op + if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + return flash_attention_op + except Exception as e: + errors.display_once(e, "enabling flash attention") + + return None + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - if shared.cmd_opts.xformers_flash_attention: - op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp - fw, bw = op - if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): - # print('xformers_attention_forward', q.shape, k.shape, v.shape) - # Flash Attention is not availabe for the input arguments. - # Fallback to default xFormers' backend. - op = None - else: - op = None - - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - if shared.cmd_opts.xformers_flash_attention: - op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp - fw, bw = op - if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)): - # print('xformers_attnblock_forward', q.shape, k.shape, v.shape) - # Flash Attention is not availabe for the input arguments. - # Fallback to default xFormers' backend. 
- op = None - else: - op = None - out = xformers.ops.memory_efficient_attention(q, k, v, op=op) + out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out From 925dd09c91e7338aef72e4ec99d67b8b57280215 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 09:03:17 -0500 Subject: [PATCH 019/151] improve interrogate --- modules/interrogate.py | 29 +++++++++++++++++------------ modules/shared.py | 1 + 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 19938cbb..1d1ac572 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") +category_types = ["artists", "flavors", "mediums", "movements"] def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") @@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir): tmpdir = content_dir + "_tmp" try: os.makedirs(tmpdir) - - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt")) - + for category_type in category_types: + torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt")) os.rename(tmpdir, content_dir) except Exception as e: @@ -51,12 +48,13 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None + self.selected_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None: - return self.loaded_categories + if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + return self.loaded_categories self.loaded_categories = [] @@ -64,14 +62,19 @@ class InterrogateModels: download_default_clip_interrogate_categories(self.content_dir) if os.path.exists(self.content_dir): - for filename in os.listdir(self.content_dir): + self.selected_categories = shared.opts.interrogate_clip_categories + for category_type in category_types: + if 'all' not in self.selected_categories and category_type not in self.selected_categories: + continue + filename = os.path.join(self.content_dir, f"{category_type}.txt") + if not os.path.isfile(filename): + continue m = re_topn.search(filename) topn = 1 if m is None else int(m.group(1)) - - with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file: + with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=filename, 
topn=topn, items=lines)) + self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) return self.loaded_categories @@ -139,6 +142,8 @@ class InterrogateModels: def rank(self, image_features, text_array, top_count=1): import clip + devices.torch_gc() + if shared.opts.interrogate_clip_dict_limit != 0: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] diff --git a/modules/shared.py b/modules/shared.py index a644c0be..63b236c5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), + "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), From 7ff1ef77dd22f7b38612f91b389237a5dbef2474 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 17:17:31 +0300 Subject: [PATCH 020/151] add a message about new torch/xformers version and a way to upgrade by specifying a commandline flag --- launch.py | 3 ++- webui.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 51c322c0..e7a0b50c 100644 --- a/launch.py +++ b/launch.py @@ -208,6 +208,7 @@ def prepare_environment(): sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') + sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch') sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') sys.argv, skip_install = extract_arg(sys.argv, '--skip-install') @@ -219,7 +220,7 @@ def prepare_environment(): print(f"Python {sys.version}") print(f"Commit hash: {commit}") - if not is_installed("torch") or not is_installed("torchvision"): + if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") if not skip_torch_cuda_test: diff --git a/webui.py b/webui.py index 7cf5885e..bc2baeab 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,7 @@ import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from packaging import version from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion @@ -49,7 +50,32 @@ else: server_name = "0.0.0.0" if cmd_opts.listen else None +def check_versions(): + expected_torch_version = "1.13.1" 
+ + if version.parse(torch.__version__) < version.parse(expected_torch_version): + errors.print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded. + """.strip()) + + expected_xformers_version = "0.0.16rc425" + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + errors.print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + """.strip()) + + def initialize(): + check_versions() + extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) From e8c3d03f7d9966b81458944efb25666b2143153f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 17:59:58 +0300 Subject: [PATCH 021/151] a possible fix for broken image upscaling --- modules/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 8514fea7..09d8e605 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -67,7 +67,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, pp.image.info["postprocessing"] = infotext if save_output: - images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if extras_mode != 2 or show_extras_results: outputs.append(pp.image) From 6e1b296baf7a2cdc0ee747225f1704bd2d45c118 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 10:10:59 -0500 Subject: [PATCH 022/151] api-image-format --- modules/api/api.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 5d60fc0a..b1dd14cc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List +import piexif +import piexif.helper def upscaler_to_index(name: str): try: @@ -56,18 +58,30 @@ def decode_base64_to_image(encoding): def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - # Copy any text-only metadata - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True + if opts.samples_format.lower() == 'png': + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + image.save(output_bytes, format="PNG", 
pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + + elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + parameters = image.info.get('parameters', None) + exif_bytes = piexif.dump({ + "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } + }) + if opts.samples_format.lower() in ("jpg", "jpeg"): + image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + else: + image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + + else: + raise HTTPException(status_code=500, detail="Invalid image format") - image.save( - output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) - ) bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) def api_middleware(app: FastAPI): From e407d1af897a7896d8c81e32dc86e7eb753ce207 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 18:12:51 +0300 Subject: [PATCH 023/151] add support for loras trained on kohya's scripts 0.4.0 (alphas) --- extensions-builtin/Lora/lora.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index da1797dc..220e64ff 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -92,6 +92,15 @@ def load_lora(name, filename): keys_failed_to_match.append(key_diffusers) continue + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "alpha": + lora_module.alpha = weight.item() + continue + if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: @@ -104,17 +113,12 @@ def load_lora(name, filename): module.to(device=devices.device, dtype=devices.dtype) - lora_module = lora.modules.get(key, None) - if lora_module is None: - lora_module = LoraUpDownModule() - lora.modules[key] = lora_module - if lora_key == "lora_up.weight": lora_module.up = module elif lora_key == "lora_down.weight": lora_module.down = module else: - assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' if len(keys_failed_to_match) > 0: print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") @@ -161,7 +165,7 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier + res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1] return res From dbcb6fac77f642e30d7b00b76cb7164a26dd4b94 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Mon, 23 Jan 2023 12:14:01 -0300 Subject: [PATCH 024/151] feat(extra-networks): replace icon background with border --- html/image-update.svg | 6 +++++- style.css | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/html/image-update.svg b/html/image-update.svg index 525e4fc5..3abf12df 100644 --- a/html/image-update.svg +++ b/html/image-update.svg @@ -1,3 +1,7 @@ - + + + + + diff --git a/style.css b/style.css index ca0a172b..9fb00e49 100644 --- a/style.css +++ b/style.css @@ -837,7 +837,6 @@ footer { display: none; font-size: 0; text-align: 
-9999; - background-color: #fff; } .extra-network-thumbs .actions .name { From c6f20f72629f3c417f10db2289d131441c6832f5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 18:52:55 +0300 Subject: [PATCH 025/151] make loras before 0.4.0 ALSO work --- extensions-builtin/Lora/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 220e64ff..137e58f7 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -57,6 +57,7 @@ class LoraUpDownModule: def __init__(self): self.up = None self.down = None + self.alpha = None def assign_lora_names_to_compvis_modules(sd_model): @@ -165,7 +166,7 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1] + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) return res From 04a561c11c9bf9a00d7f9b50ca3f7962aa59ba6e Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 12:29:23 -0500 Subject: [PATCH 026/151] add option to skip interrogate categories --- modules/interrogate.py | 32 ++++++++++++++++++-------------- modules/shared.py | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 1d1ac572..c252b148 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -2,6 +2,7 @@ import os import sys import traceback from collections import namedtuple +from pathlib import Path import re import torch @@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") -category_types = ["artists", "flavors", "mediums", "movements"] +def category_types(): + return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')] + def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") tmpdir = content_dir + "_tmp" + category_types = ["artists", "flavors", "mediums", "movements"] + try: os.makedirs(tmpdir) for category_type in category_types: @@ -48,33 +53,32 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None - self.selected_categories = [] + self.skip_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + if not os.path.exists(self.content_dir): + download_default_clip_interrogate_categories(self.content_dir) + + if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories: return self.loaded_categories self.loaded_categories = [] - if not os.path.exists(self.content_dir): - download_default_clip_interrogate_categories(self.content_dir) - if os.path.exists(self.content_dir): - self.selected_categories = shared.opts.interrogate_clip_categories - for category_type in category_types: - if 'all' not in self.selected_categories and category_type not in self.selected_categories: + self.skip_categories = shared.opts.interrogate_clip_skip_categories + category_types = [] + for filename in Path(self.content_dir).glob('*.txt'): + category_types.append(filename.stem) + if filename.stem in 
self.skip_categories: continue - filename = os.path.join(self.content_dir, f"{category_type}.txt") - if not os.path.isfile(filename): - continue - m = re_topn.search(filename) + m = re_topn.search(filename.stem) topn = 1 if m is None else int(m.group(1)) with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines)) return self.loaded_categories diff --git a/modules/shared.py b/modules/shared.py index d7a18f6a..5f713bee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,7 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), - "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}), + "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), From 865af20d8a4a823df3c950f5c9c9092a541bc57a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 21:28:59 +0300 Subject: [PATCH 027/151] suppress A matching Triton is not available message you can all now stop worrying about it --- webui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/webui.py b/webui.py index bc2baeab..e1565a8d 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,5 @@ import os import sys -import threading import time import importlib import signal @@ -10,6 +9,9 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from packaging import version +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call From 7b1c7ba87b14da9960d0347269421233f4cb5838 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 23:11:34 +0300 Subject: [PATCH 028/151] add support for apostrophe in extra network names --- html/extra-networks-card.html | 4 ++-- modules/ui_extra_networks.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 1bdf1d27..aa9fca87 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,8 +1,8 @@ -
+
{name} diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 2ddac3d8..8b4f97f8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -3,6 +3,7 @@ import os.path from modules import shared import gradio as gr import json +import html from modules.generation_parameters_copypaste import image_from_url_text @@ -54,12 +55,13 @@ class ExtraNetworksPage: preview = item.get("preview", None) args = { - "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', + "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', "prompt": item["prompt"], "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], - "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', } return self.card_page.format(**args) From 5c1cb9263f980641007088a37360fcab01761d37 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 00:24:17 +0300 Subject: [PATCH 029/151] fix BLIP failing to import depending on configuration --- modules/interrogate.py | 3 ++- modules/paths.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index c252b148..236e6983 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -83,7 +83,8 @@ class InterrogateModels: return self.loaded_categories def load_blip_model(self): - import models.blip + with paths.Prioritize("BLIP"): + import models.blip files = modelloader.load_models( model_path=os.path.join(paths.models_path, "BLIP"), diff --git a/modules/paths.py b/modules/paths.py index 4dd03a35..20b3e4d8 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,3 +38,17 @@ for d, must_exist, what, options in path_dirs: else: sys.path.append(d) paths[what] = d + + +class Prioritize: + def __init__(self, name): + self.name = name + self.path = None + + def __enter__(self): + self.path = sys.path.copy() + sys.path = [paths[self.name]] + sys.path + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.path = self.path + self.path = None From 82a28bfe35928e244d4b51d72ce424aff5619b75 Mon Sep 17 00:00:00 2001 From: Mykeehu Date: Mon, 23 Jan 2023 22:36:27 +0100 Subject: [PATCH 030/151] Fix extra network thumbs label color Added white color for labels. 
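A side note on the apostrophe fix in patch 028 above: the failure mode was that a raw name such as D'Artagnan, embedded in an inline onclick attribute, terminates the JavaScript string and the HTML attribute early. The fix composes json.dumps (JavaScript-level quoting) with html.escape (HTML-attribute-level quoting). Below is a minimal standalone sketch of that composition; the helper name and argument values are hypothetical, an illustration of the idea rather than the repository's code.

import html
import json

def onclick_attr(js_func, *args):
    # json.dumps turns each value into a valid JS literal (quoting apostrophes,
    # lowercasing booleans); html.escape then makes the whole call safe to sit
    # inside a double-quoted HTML attribute
    call = f"return {js_func}({', '.join(json.dumps(a) for a in args)})"
    return '"' + html.escape(call) + '"'

print(onclick_attr("cardClicked", "txt2img", "D'Artagnan", True))
# "return cardClicked(&quot;txt2img&quot;, &quot;D&#x27;Artagnan&quot;, true)"
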
--- style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/style.css b/style.css index b2677fa1..ec046f78 100644 --- a/style.css +++ b/style.css @@ -857,6 +857,7 @@ footer { white-space: nowrap; text-overflow: ellipsis; background: rgba(0,0,0,.5); + color: white; } .extra-network-thumbs .card:hover .actions .name { From 45e270dfc853216b2c413f915946f0f2842e57a4 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 17:11:22 -0500 Subject: [PATCH 031/151] add image decod exception handling --- modules/api/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index b1dd14cc..e6e31e41 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -53,7 +53,11 @@ def setUpscalers(req: dict): def decode_base64_to_image(encoding): if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] - return Image.open(BytesIO(base64.b64decode(encoding))) + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as err: + raise HTTPException(status_code=500, detail="Invalid encoded image") def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: From f99352582084890b9167c1bf8699865bea0cef5f Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 23 Jan 2023 21:50:59 -0500 Subject: [PATCH 032/151] Make SwinIR interruptible --- extensions-builtin/SwinIR/scripts/swinir_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 9a74b253..3479760a 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import cmd_opts, opts +from modules.shared import cmd_opts, opts, state from swinir_model_arch import SwinIR as net from swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale): with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: + if state.interrupted: + break + for w_idx in w_idx_list: + if state.interrupted: + break + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) From 3c47b050367ee220dcfed7be7883878417735614 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 23 Jan 2023 22:00:27 -0500 Subject: [PATCH 033/151] Also make SwinIR skippable --- extensions-builtin/SwinIR/scripts/swinir_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 3479760a..e8783bca 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -145,11 +145,11 @@ def inference(img, model, tile, tile_overlap, window_size, scale): with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: - if state.interrupted: + if state.interrupted or state.skipped: break for w_idx in w_idx_list: - if state.interrupted: + if 
state.interrupted or state.skipped: break in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] From 078e16e4d33ccbd40ff3ecfbb57ffd33a2a16c47 Mon Sep 17 00:00:00 2001 From: acncagua Date: Tue, 24 Jan 2023 12:21:07 +0900 Subject: [PATCH 034/151] Set Linux xformers 0.0.16RC425 --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index e7a0b50c..82094fa0 100644 --- a/launch.py +++ b/launch.py @@ -245,7 +245,7 @@ def prepare_environment(): if not is_installed("xformers"): exit(0) elif platform.system() == "Linux": - run_pip("install xformers", "xformers") + run_pip("install xformers==0.0.16rc425", "xformers") if not is_installed("pyngrok") and ngrok: run_pip("install pyngrok", "ngrok") From f64af77adcd20fabe00e1e642512db9c6742ed23 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 23 Jan 2023 22:49:20 -0500 Subject: [PATCH 035/151] Fix different first gen with Approx NN previews The loading of the model for approx nn live previews can change the internal state of PyTorch, resulting in a different image. This can be avoided by preloading the approx nn model in advance. --- modules/processing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index bc541e2f..3bd590ba 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -568,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1: + # preload approx nn model before sampling for a more deterministic result + sd_vae_approx.model() + if not p.disable_extra_networks: extra_networks.activate(p, extra_network_data) From 42a70d74771e8920f658e741679768ed145dd76a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 10:05:45 +0300 Subject: [PATCH 036/151] repair sdapi/v1/upscalers returning bogus results --- modules/api/api.py | 16 +++++++++------- modules/api/models.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index e6e31e41..da2a5daf 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -375,13 +375,15 @@ class Api: return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): - upscalers = [] - - for upscaler in shared.sd_upscalers: - u = upscaler.scaler - upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url}) - - return upscalers + return [ + { + "name": upscaler.name, + "model_name": upscaler.scaler.model_name, + "model_path": upscaler.data_path, + "scale": upscaler.scale, + } + for upscaler in shared.sd_upscalers + ] def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": 
x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] diff --git a/modules/api/models.py b/modules/api/models.py index 1eb1fcf1..e562ab54 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -219,7 +219,7 @@ class UpscalerItem(BaseModel): name: str = Field(title="Name") model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") - model_url: Optional[str] = Field(title="URL") + scale: Optional[float] = Field(title="Scale") class SDModelItem(BaseModel): title: str = Field(title="Title") From 602a1864b05075ca4283986e6f5c7d5bce864e11 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 10:09:30 +0300 Subject: [PATCH 037/151] also return the removed field to sdapi/v1/upscalers because someone might have relied on it existing --- modules/api/api.py | 1 + modules/api/models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index da2a5daf..25c65e57 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -380,6 +380,7 @@ class Api: "name": upscaler.name, "model_name": upscaler.scaler.model_name, "model_path": upscaler.data_path, + "model_url": None, "scale": upscaler.scale, } for upscaler in shared.sd_upscalers diff --git a/modules/api/models.py b/modules/api/models.py index e562ab54..805bd8f7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -219,6 +219,7 @@ class UpscalerItem(BaseModel): name: str = Field(title="Name") model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") + model_url: Optional[str] = Field(title="URL") scale: Optional[float] = Field(title="Scale") class SDModelItem(BaseModel): From d30ac02f28bf5fa1ca5d4ba444180ba9e50b4860 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:21:32 -0500 Subject: [PATCH 038/151] renamed xy to xyz grid this is mostly just so git can detect it properly --- scripts/{xy_grid.py => xyz_grid.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{xy_grid.py => xyz_grid.py} (100%) diff --git a/scripts/xy_grid.py b/scripts/xyz_grid.py similarity index 100% rename from scripts/xy_grid.py rename to scripts/xyz_grid.py From 9fc354e1303453bbd865cbede86da4c3273ed14f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:22:40 -0500 Subject: [PATCH 039/151] implements most of xyz grid script --- scripts/xyz_grid.py | 114 ++++++++++++++++++++++++++++++-------------- 1 file changed, 78 insertions(+), 36 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 1a452355..494e8417 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,26 +205,30 @@ axis_options = [ ] -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + title_texts = [[images.GridAnnotation(z)] for z in z_labels] # Temporary list of all the images that are generated to be populated into the grid. 
# Will be filled with empty images for any individual step that fails to process properly - image_cache = [None] * (len(xs) * len(ys)) + image_cache = [None] * (len(xs) * len(ys) * len(zs)) processed_result = None cell_mode = "P" cell_size = (1, 1) - state.job_count = len(xs) * len(ys) * p.n_iter + state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter - def process_cell(x, y, ix, iy): + def process_cell(x, y, z, ix, iy, iz): nonlocal image_cache, processed_result, cell_mode, cell_size - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + def index(ix, iy, iz): + return ix + iy*len(xs) + iz*len(xs)*len(ys) - processed: Processed = cell(x, y) + state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}" + + processed: Processed = cell(x, y, z) try: # this dereference will throw an exception if the image was not processed @@ -238,33 +242,40 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ cell_size = processed_image.size processed_result.images = [Image.new(cell_mode, cell_size)] - image_cache[ix + iy * len(xs)] = processed_image + image_cache[index(ix, iy, iz)] = processed_image if include_lone_images: processed_result.images.append(processed_image) processed_result.all_prompts.append(processed.prompt) processed_result.all_seeds.append(processed.seed) processed_result.infotexts.append(processed.infotexts[0]) except: - image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) + image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size) if swap_axes_processing_order: for ix, x in enumerate(xs): for iy, y in enumerate(ys): - process_cell(x, y, ix, iy) + for iy, y in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) else: for iy, y in enumerate(ys): for ix, x in enumerate(xs): - process_cell(x, y, ix, iy) + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) if not processed_result: - print("Unexpected error: draw_xy_grid failed to return even a single processed image") + print("Unexpected error: draw_xyz_grid failed to return even a single processed image") return Processed(p, []) - grid = images.image_grid(image_cache, rows=len(ys)) - if draw_legend: - grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - - processed_result.images[0] = grid + for i, title_text in enumerate(title_texts): + start_index = i * len(xs) * len(ys) + end_index = start_index + len(xs) * len(ys) + grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys)) + if draw_legend: + grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) + if i == 0: # First position is a placeholder as mentioned above, so it can be directly replaced + processed_result.images[0] = grid + else: + processed_result.images.insert(i, grid) return processed_result @@ -291,7 +302,7 @@ re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+ class Script(scripts.Script): def title(self): - return "X/Y plot" + return "X/Y/Z plot" def ui(self, is_img2img): self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] @@ -301,24 +312,35 @@ class Script(scripts.Script): with gr.Row(): x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - fill_x_button = ToolButton(value=fill_values_symbol, 
elem_id="xy_grid_fill_x_tool_button", visible=False) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False) + + with gr.Row(): + z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) + z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) + fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) - swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") + swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") + swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") + swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") - def swap_axes(x_type, x_values, y_type, y_values): - return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values + def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values): + return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values - swap_args = [x_type, x_values, y_type, y_values] - swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + xy_swap_args = [x_type, x_values, y_type, y_values] + swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) + yz_swap_args = [y_type, y_values, z_type, z_values] + swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) + xz_swap_args = [x_type, x_values, z_type, z_values] + swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) def fill(x_type): axis = self.current_axis_options[x_type] @@ -326,16 +348,18 @@ class Script(scripts.Script): fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) + fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values]) def select_axis(x_type): return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) + z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) - return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] + return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds] - def 
run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): + def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -409,6 +433,9 @@ class Script(scripts.Script): y_opt = self.current_axis_options[y_type] ys = process_axis(y_opt, y_values) + z_opt = self.current_axis_options[z_type] + zs = process_axis(z_opt, z_values) + def fix_axis_seeds(axis_opt, axis_list): if axis_opt.label in ['Seed', 'Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] @@ -418,21 +445,26 @@ class Script(scripts.Script): if not no_fixed_seeds: xs = fix_axis_seeds(x_opt, xs) ys = fix_axis_seeds(y_opt, ys) + zs = fix_axis_seeds(z_opt, zs) if x_opt.label == 'Steps': - total_steps = sum(xs) * len(ys) + total_steps = sum(xs) * len(ys) * len(zs) elif y_opt.label == 'Steps': - total_steps = sum(ys) * len(xs) + total_steps = sum(ys) * len(xs) * len(zs) + elif z_opt.label == 'Steps': + total_steps = sum(zs) * len(xs) * len(ys) else: - total_steps = p.steps * len(xs) * len(ys) + total_steps = p.steps * len(xs) * len(ys) * len(zs) if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: if x_opt.label == "Hires steps": - total_steps += sum(xs) * len(ys) + total_steps += sum(xs) * len(ys) * len(zs) elif y_opt.label == "Hires steps": - total_steps += sum(ys) * len(xs) + total_steps += sum(ys) * len(xs) * len(zs) + elif z_opt.label == "Hires steps": + total_steps += sum(zs) * len(xs) * len(ys) elif p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) else: total_steps *= 2 @@ -440,7 +472,8 @@ class Script(scripts.Script): image_cell_count = p.n_iter * p.batch_size cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" - print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") + plural_s = 's' if len(zs) > 1 else '' + print(f"X/Y plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})") shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] @@ -450,13 +483,14 @@ class Script(scripts.Script): # `for` loop. swap_axes_processing_order = x_opt.cost > y_opt.cost - def cell(x, y): + def cell(x, y, z): if shared.state.interrupted: return Processed(p, [], p.seed, "") pc = copy(p) x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) + z_opt.apply(pc, z, zs) res = process_images(pc) @@ -475,17 +509,25 @@ class Script(scripts.Script): if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) + if z_opt.label != 'Nothing': + pc.extra_generation_params["Z Type"] = z_opt.label + pc.extra_generation_params["Z Values"] = z_values + if z_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs]) + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) return res with SharedSettingsStackHelper(): - processed = draw_xy_grid( + processed = draw_xyz_grid( p, xs=xs, ys=ys, + zs=zs, x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + z_labels=[y_opt.format_value(p, z_opt, z) for z in zs], cell=cell, draw_legend=draw_legend, include_lone_images=include_lone_images, @@ -493,6 +535,6 @@ class Script(scripts.Script): ) if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) return processed From e46bfa5a9e9b489ae925a9c23880e34fe8d9fffa Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:24:32 -0500 Subject: [PATCH 040/151] handling sub grids and merging into one --- modules/images.py | 2 +- scripts/xyz_grid.py | 29 ++++++++++++++++++----------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/modules/images.py b/modules/images.py index 3b1c5f34..0bc3d524 100644 --- a/modules/images.py +++ b/modules/images.py @@ -195,7 +195,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts] - pad_top = max(hor_text_heights) + line_spacing * 2 + pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white") result.paste(im, (pad_left, pad_top)) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 494e8417..a16653da 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,9 +205,9 @@ axis_options = [ ] -def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order): hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] title_texts = [[images.GridAnnotation(z)] for z in z_labels] # Temporary list of all the images that are generated to be populated into the grid. 
@@ -266,16 +266,21 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend print("Unexpected error: draw_xyz_grid failed to return even a single processed image") return Processed(p, []) - for i, title_text in enumerate(title_texts): + grids = [None] * len(zs) + for i in range(len(zs)): start_index = i * len(xs) * len(ys) end_index = start_index + len(xs) * len(ys) grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys)) if draw_legend: grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - if i == 0: # First position is a placeholder as mentioned above, so it can be directly replaced - processed_result.images[0] = grid - else: - processed_result.images.insert(i, grid) + + grids[i] = grid + if include_sub_grids and len(zs) > 1: + processed_result.images.insert(i+1, grid) + + original_grid_size = grids[0].size + grids = images.image_grid(grids, rows=1) + processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]]) return processed_result @@ -326,7 +331,8 @@ class Script(scripts.Script): with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) + include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") @@ -357,9 +363,9 @@ class Script(scripts.Script): y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) - return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds] + return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds] - def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds): + def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -527,10 +533,11 @@ class Script(scripts.Script): zs=zs, x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - z_labels=[y_opt.format_value(p, z_opt, z) for z in zs], + z_labels=[z_opt.format_value(p, z_opt, z) for z in zs], cell=cell, draw_legend=draw_legend, include_lone_images=include_lone_images, + include_sub_grids=include_sub_grids, swap_axes_processing_order=swap_axes_processing_order ) From ec8774729e17f87a8ffa5a3c5328d12834cbb02a Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:53:35 -0500 Subject: [PATCH 041/151] swaps xyz axes internally if one costs more --- scripts/xyz_grid.py | 66 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git 
a/scripts/xyz_grid.py b/scripts/xyz_grid.py index a16653da..828c2d12 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,7 +205,7 @@ axis_options = [ ] -def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order): +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed): hor_texts = [[images.GridAnnotation(x)] for x in x_labels] ver_texts = [[images.GridAnnotation(y)] for y in y_labels] title_texts = [[images.GridAnnotation(z)] for z in z_labels] @@ -224,7 +224,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend nonlocal image_cache, processed_result, cell_mode, cell_size def index(ix, iy, iz): - return ix + iy*len(xs) + iz*len(xs)*len(ys) + return ix + iy * len(xs) + iz * len(xs) * len(ys) state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}" @@ -251,16 +251,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend except: image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size) - if swap_axes_processing_order: + if first_axes_processed == 'x': for ix, x in enumerate(xs): - for iy, y in enumerate(ys): - for iy, y in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) - else: - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): + if second_axes_processed == 'y': + for iy, y in enumerate(ys): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: for iz, z in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'y': + for iy, y in enumerate(ys): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'z': + for iz, z in enumerate(zs): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) if not processed_result: print("Unexpected error: draw_xyz_grid failed to return even a single processed image") @@ -322,7 +342,7 @@ class Script(scripts.Script): with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) with gr.Row(): z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) @@ -487,7 +507,26 @@ class Script(scripts.Script): # If one of the axes is very slow to change between (like SD model # checkpoint), then make sure it is in the outer iteration of the nested # `for` loop. 
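The change just below replaces the single swap flag with an explicit choice of first and second axis. Up to how ties are broken, the if/elif ladder it introduces is equivalent to sorting the three axes by cost, highest first; a compact sketch of that equivalence (not the patch's code, and the example costs are invented):

def processing_order(costs):
    # the costliest axis should be the outermost loop, so that expensive state
    # changes (e.g. swapping model checkpoints) happen as rarely as possible
    order = sorted(costs, key=costs.get, reverse=True)
    return order[0], order[1]

assert processing_order({'x': 0.5, 'y': 1.0, 'z': 0.1}) == ('y', 'x')
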
- swap_axes_processing_order = x_opt.cost > y_opt.cost + first_axes_processed = 'x' + second_axes_processed = 'y' + if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: + first_axes_processed = 'x' + if y_opt.cost > z_opt.cost: + second_axes_processed = 'y' + else: + second_axes_processed = 'z' + elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost: + first_axes_processed = 'y' + if x_opt.cost > z_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'z' + elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost: + first_axes_processed = 'z' + if x_opt.cost > y_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'y' def cell(x, y, z): if shared.state.interrupted: @@ -538,7 +577,8 @@ class Script(scripts.Script): draw_legend=draw_legend, include_lone_images=include_lone_images, include_sub_grids=include_sub_grids, - swap_axes_processing_order=swap_axes_processing_order + first_axes_processed=first_axes_processed, + second_axes_processed=second_axes_processed ) if opts.grid_save: From dac45299dd57c6cb240424b93fd28a085605bd90 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:22:19 +0300 Subject: [PATCH 042/151] make git commands not fail for extensions when you have spaces in webui directory --- launch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/launch.py b/launch.py index 82094fa0..6d523a34 100644 --- a/launch.py +++ b/launch.py @@ -108,18 +108,18 @@ def git_clone(url, dir, name, commithash=None): if commithash is None: return - current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() + current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() if current_hash == commithash: return - run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") + run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") + run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") return run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") if commithash is not None: - run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") + run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") def version_check(commit): From 28189985e6f56dc725938a3f0e4d2462dad74bc5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:24:27 +0300 Subject: [PATCH 043/151] remove fairscale requirement, add fake fairscale to make BLIP not complain about it --- modules/interrogate.py | 11 +++++++++-- requirements.txt | 1 - requirements_versions.txt | 1 - 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 236e6983..9f063197 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -82,9 +82,16 @@ class InterrogateModels: return self.loaded_categories + def create_fake_fairscale(self): + class FakeFairscale: + def checkpoint_wrapper(self): + pass + + sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale + 
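A quick note on why the stub above works: Python's import machinery consults sys.modules before searching the filesystem, and the from ... import form resolves a dotted name straight from that cache, so registering an object under the full dotted path satisfies BLIP's import without fairscale's parent packages existing at all. A self-contained illustration with a hypothetical dependency name (not the repository's code):

import sys
import types

# pre-register a stub under the full dotted name of a missing dependency
stub = types.ModuleType("heavy_dep.nn.wrappers")
stub.checkpoint_wrapper = lambda module, **kwargs: module  # no-op stand-in
sys.modules["heavy_dep.nn.wrappers"] = stub

# the leaf entry is found in sys.modules, so the parents never need to exist
from heavy_dep.nn.wrappers import checkpoint_wrapper
assert checkpoint_wrapper("unchanged") == "unchanged"
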
def load_blip_model(self): - with paths.Prioritize("BLIP"): - import models.blip + create_fake_fairscale() + import models.blip files = modelloader.load_models( model_path=os.path.join(paths.models_path, "BLIP"), diff --git a/requirements.txt b/requirements.txt index ef5e3472..a4be1ec3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ blendmodes accelerate basicsr -fairscale==0.4.4 fonts font-roboto gfpgan diff --git a/requirements_versions.txt b/requirements_versions.txt index f97ad765..135908be 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -14,7 +14,6 @@ scikit-image==0.19.2 fonts font-roboto timm==0.6.7 -fairscale==0.4.9 piexif==1.1.3 einops==0.4.1 jsonmerge==1.8.0 From 5228ec8bdada50a8d614573e980193ca89192361 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:30:43 +0300 Subject: [PATCH 044/151] remove fairscale requirement, add fake fairscale to make BLIP not complain about it mk2 --- modules/interrogate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 9f063197..c72ff694 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -90,7 +90,7 @@ class InterrogateModels: sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale def load_blip_model(self): - create_fake_fairscale() + self.create_fake_fairscale() import models.blip files = modelloader.load_models( From 93fad28a979727f9b1331dbdc447598824057cdc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 21:13:05 +0300 Subject: [PATCH 045/151] print progress when installing torch add PIP_INSTALLER_LOCATION env var to install pip if it's not installed remove accidental call to accelerate when venv is disabled add another env var to skip venv - SKIP_VENV --- launch.py | 20 +++++++++++++++++--- webui.bat | 8 +++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/launch.py b/launch.py index 6d523a34..f578c1c7 100644 --- a/launch.py +++ b/launch.py @@ -48,10 +48,19 @@ def extract_opt(args, name): return args, is_present, opt -def run(command, desc=None, errdesc=None, custom_env=None): +def run(command, desc=None, errdesc=None, custom_env=None, live=False): if desc is not None: print(desc) + if live: + result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env) + if result.returncode != 0: + raise RuntimeError(f"""{errdesc or 'Error running command'}. 
+Command: {command} +Error code: {result.returncode}""") + + return "" + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) if result.returncode != 0: @@ -179,6 +188,8 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install + pip_installer_location = os.environ.get('PIP_INSTALLER_LOCATION', None) + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -219,9 +230,12 @@ def prepare_environment(): print(f"Python {sys.version}") print(f"Commit hash: {commit}") - + + if pip_installer_location is not None and not is_installed("pip"): + run(f'"{python}" "{pip_installer_location}"', "Installing pip", "Couldn't install pip") + if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): - run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") + run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) if not skip_torch_cuda_test: run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") diff --git a/webui.bat b/webui.bat index 3165b94d..0d6865c9 100644 --- a/webui.bat +++ b/webui.bat @@ -3,6 +3,7 @@ if not defined PYTHON (set PYTHON=python) if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") + set ERROR_REPORTING=FALSE mkdir tmp 2>NUL @@ -14,6 +15,7 @@ goto :show_stdout_stderr :start_venv if ["%VENV_DIR%"] == ["-"] goto :skip_venv +if ["%SKIP_VENV%"] == ["1"] goto :skip_venv dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt if %ERRORLEVEL% == 0 goto :activate_venv @@ -28,13 +30,13 @@ goto :show_stdout_stderr :activate_venv set PYTHON="%VENV_DIR%\Scripts\Python.exe" echo venv %PYTHON% + +:skip_venv if [%ACCELERATE%] == ["True"] goto :accelerate goto :launch -:skip_venv - :accelerate -echo "Checking for accelerate" +echo Checking for accelerate set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" if EXIST %ACCELERATE% goto :accelerate_launch From bef193189500884c2b20605290ac8bef8251a788 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 23:50:04 +0300 Subject: [PATCH 046/151] add fastapi to requirements --- requirements_versions.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_versions.txt b/requirements_versions.txt index 135908be..1c328d44 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -4,6 +4,7 @@ accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 gradio==3.16.2 +fastapi==0.82.0 numpy==1.23.3 Pillow==9.4.0 realesrgan==0.3.0 From 48a15821de768fea76e66f26df83df3fddf18f4b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 00:49:16 +0300 Subject: [PATCH 047/151] remove the pip install stuff because it does not work as i hoped it would --- launch.py | 5 ----- requirements_versions.txt | 1 - webui.bat | 13 +++++++++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/launch.py b/launch.py index f578c1c7..9d6f4a8c 100644 --- a/launch.py +++ b/launch.py @@ -188,8 +188,6 @@ def run_extensions_installers(settings_file): def prepare_environment(): global 
skip_install - pip_installer_location = os.environ.get('PIP_INSTALLER_LOCATION', None) - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -231,9 +229,6 @@ def prepare_environment(): print(f"Python {sys.version}") print(f"Commit hash: {commit}") - if pip_installer_location is not None and not is_installed("pip"): - run(f'"{python}" "{pip_installer_location}"', "Installing pip", "Couldn't install pip") - if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) diff --git a/requirements_versions.txt b/requirements_versions.txt index 1c328d44..135908be 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -4,7 +4,6 @@ accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 gradio==3.16.2 -fastapi==0.82.0 numpy==1.23.3 Pillow==9.4.0 realesrgan==0.3.0 diff --git a/webui.bat b/webui.bat index 0d6865c9..209d972b 100644 --- a/webui.bat +++ b/webui.bat @@ -9,10 +9,19 @@ set ERROR_REPORTING=FALSE mkdir tmp 2>NUL %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt -if %ERRORLEVEL% == 0 goto :start_venv +if %ERRORLEVEL% == 0 goto :check_pip echo Couldn't launch python goto :show_stdout_stderr +:check_pip +%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :start_venv +if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr +%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :start_venv +echo Couldn't install pip +goto :show_stdout_stderr + :start_venv if ["%VENV_DIR%"] == ["-"] goto :skip_venv if ["%SKIP_VENV%"] == ["1"] goto :skip_venv @@ -46,7 +55,7 @@ pause exit /b :accelerate_launch -echo "Accelerating" +echo Accelerating %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py pause exit /b From 84d9ce30cb427759547bc7876ed80ab91787d175 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 24 Jan 2023 23:51:45 -0500 Subject: [PATCH 048/151] Add option for float32 sampling with float16 UNet This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half(). --- README.md | 1 + modules/deepbooru_model.py | 4 +++- modules/devices.py | 2 ++ modules/processing.py | 15 ++++++++------- modules/sd_hijack_unet.py | 29 +++++++++++++++++++++++++++++ modules/sd_hijack_utils.py | 28 ++++++++++++++++++++++++++++ modules/sd_models.py | 10 ++++++++++ modules/shared.py | 1 + 8 files changed, 82 insertions(+), 8 deletions(-) create mode 100644 modules/sd_hijack_utils.py diff --git a/README.md b/README.md index 9c0cd1ef..a5611671 100644 --- a/README.md +++ b/README.md @@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru - Security advice - RyotaK - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. 
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6) - (You) diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py index edd40c81..83d2ff09 100644 --- a/modules/deepbooru_model.py +++ b/modules/deepbooru_model.py @@ -2,6 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from modules import devices + # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more @@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module): t_358, = inputs t_359 = t_358.permute(*[0, 3, 1, 2]) t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) - t_360 = self.n_Conv_0(t_359_padded) + t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded) t_361 = F.relu(t_360) t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) t_362 = self.n_MaxPool_0(t_361) diff --git a/modules/devices.py b/modules/devices.py index 524ec7af..0981ef80 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -79,6 +79,8 @@ cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 +dtype_unet = torch.float16 +unet_needs_upcast = False def randn(seed, shape): diff --git a/modules/processing.py b/modules/processing.py index bc541e2f..2d186ba0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,8 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -203,7 +204,7 @@ class StableDiffusionProcessing: # Create another latent image, this time with a masked version of the original input. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. - conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype) + conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype) conditioning_image = torch.lerp( source_image, source_image * (1.0 - conditioning_mask), @@ -211,7 +212,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -225,10 +226,10 @@ class StableDiffusionProcessing: # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. 
The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image) + return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(): + with devices.autocast(disable=devices.unet_needs_upcast): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] @@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(shared.device) + image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 18daf8c1..88c94e54 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,4 +1,8 @@ import torch +from packaging import version + +from modules import devices +from modules.sd_hijack_utils import CondFunc class TorchHijackForUnet: @@ -28,3 +32,28 @@ class TorchHijackForUnet: th = TorchHijackForUnet() + + +# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): + return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + +class GELUHijack(torch.nn.GELU, torch.nn.Module): + def __init__(self, *args, **kwargs): + torch.nn.GELU.__init__(self, *args, **kwargs) + def forward(self, x): + if devices.unet_needs_upcast: + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + else: + return torch.nn.GELU.forward(self, x) + +unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +if version.parse(torch.__version__) <= version.parse("1.13.1"): + CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + 
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py new file mode 100644 index 00000000..f81b169a --- /dev/null +++ b/modules/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-2, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..7c98991a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half: vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 if shared.cmd_opts.no_half_vae: model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None model.half() model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -372,6 +380,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True timer = Timer() diff --git a/modules/shared.py b/modules/shared.py index 5f713bee..4ce1209b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", 
choices=["full", "autocast"], default="autocast") +parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") From e3b53fd295aca784253dfc8668ec87b537a72f43 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 25 Jan 2023 00:23:10 -0500 Subject: [PATCH 049/151] Add UI setting for upcasting attention to float32 Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also. --- modules/devices.py | 6 +- modules/processing.py | 2 +- modules/sd_hijack_optimizations.py | 159 ++++++++++++++++++----------- modules/shared.py | 1 + modules/sub_quadratic_attention.py | 4 +- 5 files changed, 108 insertions(+), 64 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0981ef80..6b36622c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -108,6 +108,10 @@ def autocast(disable=False): return torch.autocast("cuda") +def without_autocast(disable=False): + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + + class NansException(Exception): pass @@ -125,7 +129,7 @@ def test_for_nans(x, where): message = "A tensor with all NaNs was produced in Unet." if not shared.cmd_opts.no_half: - message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." 
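The hunks that follow apply the same upcast pattern to every attention implementation: cast q and k (and, per the findings above, v) to float32, run the attention math with autocast disabled so intermediates stay in full precision, then cast the result back. A minimal sketch of that pattern, assuming a generic attention callable (`attn_fn` and the function name are illustrative, not part of the webui API):

import contextlib
import torch

def upcast_attention(q, k, v, attn_fn, upcast=True):
    dtype = q.dtype
    if upcast:
        # float32 q/k/v keeps the softmax and matmul intermediates in full precision
        q, k, v = q.float(), k.float(), v.float()
    # mirror devices.without_autocast(): disable autocast only if it is active
    ctx = torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() else contextlib.nullcontext()
    with ctx:
        out = attn_fn(q, k, v)
    return out.to(dtype)  # restore the caller's dtype (e.g. float16)

Casting back at the end means downstream layers still see the model's native dtype, which is why each hunk records `dtype = q.dtype` up front before converting.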
diff --git a/modules/processing.py b/modules/processing.py index 2d186ba0..a850082d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(disable=devices.unet_needs_upcast): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 74452709..c02d954c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors +from modules import shared, errors, devices from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() - s2 = s1.softmax(dim=-1) - del s1 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + del q, k, v - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - k_in *= self.scale + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - del context, x + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + 
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): context = default(context, x) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) * self.scale + k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = einsum_op(q, k, v) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) # -- End of code from https://github.com/invoke-ai/InvokeAI -- @@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = x.to(dtype) + x = x.unflatten(0, (-1, 
h)).transpose(1,2).flatten(start_dim=2) out_proj, dropout = self.to_out @@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ query_chunk_size = q_tokens kv_chunk_size = k_tokens - return efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) def get_xformers_flash_attention_op(q, k, v): @@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 4ce1209b..6a0b96cb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 55052815..05595323 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -67,7 +67,7 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) + exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) @@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking( ) attn_probs = attn_scores.softmax(dim=-1) del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) + hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, 
value.to(attn_probs.dtype)).to(value.dtype) return hidden_states_slice From 1bfec873fa13d803f3d4ac2a12bf6983838233fe Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 11:29:46 +0300 Subject: [PATCH 050/151] add an experimental option to apply loras to outputs rather than inputs --- extensions-builtin/Lora/lora.py | 5 ++++- extensions-builtin/Lora/scripts/lora_script.py | 7 ++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 137e58f7..cb8f1d36 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -166,7 +166,10 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + if shared.opts.lora_apply_to_outputs and res.shape == input.shape: + res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + else: + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) return res diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 60b9eb64..544b228d 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -3,7 +3,7 @@ import torch import lora import extra_networks_lora import ui_extra_networks_lora -from modules import script_callbacks, ui_extra_networks, extra_networks +from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): @@ -28,3 +28,8 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) script_callbacks.on_before_ui(before_ui) + + +shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), +})) From ee0a0da3244123cb6d2ba4097a54a1e9caccb687 Mon Sep 17 00:00:00 2001 From: Kyle Date: Wed, 25 Jan 2023 08:53:23 -0500 Subject: [PATCH 051/151] Add instruct-pix2pix hijack Allows loading instruct-pix2pix models via same method as inpainting models in sd_models.py and sd_hijack_ip2p.py Adds ddpm_edit.py necessary for instruct-pix2pix --- modules/models/diffusion/ddpm_edit.py | 1459 +++++++++++++++++++++++++ modules/sd_hijack_ip2p.py | 13 + modules/sd_models.py | 12 +- 3 files changed, 1483 insertions(+), 1 deletion(-) create mode 100644 modules/models/diffusion/ddpm_edit.py create mode 100644 modules/sd_hijack_ip2p.py diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py new file mode 100644 index 00000000..f3d49c44 --- /dev/null +++ b/modules/models/diffusion/ddpm_edit.py @@ -0,0 +1,1459 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +# File modified by authors of InstructPix2Pix from original 
(https://github.com/CompVis/stable-diffusion). +# See more details in LICENSE. + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + load_ema=True, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + if self.use_ema and load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + # If initialing from EMA-only checkpoint, create EMA model after loading. 
+ if self.use_ema and not load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. 
* 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + + # Our model adds additional channels to the first layer to condition on an input image. + # For the first layer, copy existing channel weights and initialize new channel weights to zero. + input_keys = [ + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ] + + self_sd = self.state_dict() + for input_key in input_keys: + if input_key not in sd or input_key not in self_sd: + continue + + input_weight = self_sd[input_key] + + if input_weight.size() != sd[input_key].size(): + print(f"Manual init: {input_key}") + input_weight.zero_() + input_weight[:, :4, :, :].copy_(sd[input_key]) + ignore_keys.append(input_key) + + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
+ """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + 
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + return batch[k] + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + 
log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + load_ema=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + if self.use_ema and not load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, uncond=0.05): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = super().get_input(batch, cond_key) + 
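+        # For instruct-pix2pix, xc is a dict: xc["c_crossattn"] carries the edit
+        # instruction text and xc["c_concat"] carries the image being edited.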
if bs is not None: + xc["c_crossattn"] = xc["c_crossattn"][:bs] + xc["c_concat"] = xc["c_concat"][:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. + random = torch.rand(x.size(0), device=x.device) + prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1") + + null_prompt = self.get_learned_conditioning([""]) + cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())] + cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()] + + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. 
/ self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, 
axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
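+        # Editor's note (added comment, not in the original patch): when clip_denoised
+        # is set, the predicted x0 above is clamped to the pixel-space range [-1, 1];
+        # latent-space models typically run with clip_denoised=False, since VAE latents
+        # are not bounded to that interval.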
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, + plot_diffusion_rows=False, **kwargs): + + use_ddim = False + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + 
force_c_encode=True, + return_original_cond=True, + bs=N, uncond=0) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reals"] = xc["c_concat"] + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] 
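+            # Editor's note (added comment, not in the original patch): mask is 1 where
+            # the original latent is kept and 0 in the centre square that the sampler
+            # regenerates; [:, None, ...] adds the channel dimension the latents carry.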
+ with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
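+        # Editor's note (added comment, not in the original patch): self.colorize is a
+        # fixed random (3, C, 1, 1) kernel, so the conv above projects C channels down to
+        # RGB; the line above then min-max rescales the result into [-1, 1] for logging.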
+ return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py new file mode 100644 index 00000000..635f015f --- /dev/null +++ b/modules/sd_hijack_ip2p.py @@ -0,0 +1,13 @@ +import collections +import os.path +import sys +import gc +import time + +def should_hijack_ip2p(checkpoint_info): + from modules import sd_models + + ckpt_basename = os.path.basename(checkpoint_info.filename).lower() + cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + + return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..cddc2343 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting +from modules.sd_hijack_ip2p import should_hijack_ip2p model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -365,6 +366,15 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None + if should_hijack_ip2p(checkpoint_info): + sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.first_stage_key = "edited" + 
sd_config.model.params.cond_stage_key = "edit"
+        sd_config.model.params.image_size = 16
+        sd_config.model.params.unet_config.params.in_channels = 8
+        sd_config.model.params.unet_config.params.out_channels = 4
+
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False

@@ -429,7 +439,7 @@ def reload_model_weights(sd_model=None, info=None):

     checkpoint_config = find_checkpoint_config(current_checkpoint_info)

-    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
         del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info)

From bd9b55ee908c43fb1b654b3a3a1320545023ce1c Mon Sep 17 00:00:00 2001
From: Kyle
Date: Wed, 25 Jan 2023 09:41:41 -0500
Subject: [PATCH 052/151] Update requirements transformers==4.25.1

Update requirement for transformers to version 4.25.1 to allow
instruct-pix2pix demo code to work

---
 requirements.txt          | 2 +-
 requirements_versions.txt | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index a4be1ec3..6d53f089 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,7 @@ pytorch_lightning==1.7.7
 realesrgan
 scikit-image>=0.19
 timm==0.4.12
-transformers==4.19.2
+transformers==4.25.1
 torch
 einops
 jsonmerge
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 135908be..eaa08806 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,5 +1,5 @@
 blendmodes==2022
-transformers==4.19.2
+transformers==4.25.1
 accelerate==0.12.0
 basicsr==1.4.2
 gfpgan==1.3.8

From 57c1baa774d07060af0abbd2974c5f36c8cb63ac Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 18:56:23 +0300
Subject: [PATCH 053/151] change to code for live preview fix on OSX to be a bit more obvious

---
 modules/processing.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index 3bd590ba..57c3db1b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -568,8 +568,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

-        if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1:
-            # preload approx nn model before sampling for a more deterministic result
+        # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
+        if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
             sd_vae_approx.model()

         if not p.disable_extra_networks:

From 635499e8329dfd8c4c5ccca180881867f34a9f36 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 19:42:26 +0300
Subject: [PATCH 054/151] add pix2pix credits

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index a5611671..c6bd6f27 100644
--- a/README.md
+++ b/README.md
@@ -155,7 +155,8 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Idea for Composable Diffusion - 
https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch - xformers - https://github.com/facebookresearch/xformers - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru +- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6) +- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix - Security advice - RyotaK - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. -- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6) - (You) From e179b6098ac1b1ce9645fef5bd9fd0bc9b918f30 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Wed, 25 Jan 2023 08:48:40 -0800 Subject: [PATCH 055/151] allow symlinks in the textual inversion embeddings folder --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4e90f690..6cf00e65 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -194,7 +194,7 @@ class EmbeddingDatabase: if not os.path.isdir(embdir.path): return - for root, dirs, fns in os.walk(embdir.path): + for root, dirs, fns in os.walk(embdir.path, followlinks=True): for fn in fns: try: fullfn = os.path.join(root, fn) From 789d47f832a5c921dbbdd0a657dff9bca7f78d94 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 19:55:31 +0300 Subject: [PATCH 056/151] make clicking extra networks button one more time close the extra networks UI --- modules/ui_extra_networks.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 8b4f97f8..c6ff889a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -117,8 +117,13 @@ def create_ui(container, button, tabname): ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) - button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) + def toggle_visibility(is_visible): + is_visible = not is_visible + return is_visible, gr.update(visible=is_visible) + + state_visible = gr.State(value=False) + button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) + button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) def refresh(): res = [] From e425b9812b067073eb6edfafac689735f5391b45 Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:07:48 +0500 Subject: [PATCH 057/151] Added Python version check --- launch.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/launch.py b/launch.py index 9d6f4a8c..86b4a32b 100644 --- a/launch.py +++ b/launch.py @@ -17,6 +17,17 @@ stored_commit_hash = None skip_install = False +def check_python_version(): + version = 
sys.version_info + version_range = None + if os.name == "nt": + version_range = range(7, 11) + else: + version_range = range(7, 12) + + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + + def commit_hash(): global stored_commit_hash @@ -321,5 +332,6 @@ def start(): if __name__ == "__main__": + check_python_version() prepare_environment() start() From 15e89ef0f6f22f823c19592a401b9e4ee477258c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 20:11:01 +0300 Subject: [PATCH 058/151] fix for unet hijack breaking the train tab --- modules/sd_hijack_unet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 88c94e54..a6ee577c 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -36,8 +36,11 @@ th = TorchHijackForUnet() # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - for y in cond.keys(): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + + if isinstance(cond, dict): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() From 57096823fadbc18b33d9b89d2d3a02d5ebba29f4 Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:33:35 +0500 Subject: [PATCH 059/151] Remove a stacktrace from an assertion to not scare people --- launch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 86b4a32b..cf747e72 100644 --- a/launch.py +++ b/launch.py @@ -25,7 +25,10 @@ def check_python_version(): else: version_range = range(7, 12) - assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + try: + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + except AssertionError as e: + print(e) def commit_hash(): From 2de99d62dd80123bf2d7dcbb2c4970fad5d92d42 Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:38:28 +0500 Subject: [PATCH 060/151] some clarification --- launch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index cf747e72..4608bc81 100644 --- a/launch.py +++ b/launch.py @@ -26,9 +26,10 @@ def check_python_version(): version_range = range(7, 12) try: - assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. 
You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first." except AssertionError as e: print(e) + sys.exit(-1) def commit_hash(): From 0cc5f380d5a21625413554a6a64b97172b36d64a Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:41:51 +0500 Subject: [PATCH 061/151] even more clarifications(?) i have no idea what commit message should be --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 4608bc81..801e0371 100644 --- a/launch.py +++ b/launch.py @@ -26,7 +26,7 @@ def check_python_version(): version_range = range(7, 12) try: - assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first." + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too." except AssertionError as e: print(e) sys.exit(-1) From f5d73b6a6646b51027c8e6f6c6154f21b58d6af2 Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:56:09 +0500 Subject: [PATCH 062/151] Fixed typo --- launch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launch.py b/launch.py index 801e0371..e39c68e7 100644 --- a/launch.py +++ b/launch.py @@ -21,9 +21,9 @@ def check_python_version(): version = sys.version_info version_range = None if os.name == "nt": - version_range = range(7, 11) + version_range = range(7 + 1, 10 + 1) else: - version_range = range(7, 12) + version_range = range(7 + 1, 11 + 1) try: assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too." 
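Editor's note (not part of the patch series): the fixups above all adjust the same
range-based guard in check_python_version(). A minimal equivalent sketch using tuple
comparison, assuming the bounds PATCH 062 settles on (minors 8-10 on Windows, 8-11
elsewhere):

    import os
    import sys

    def check_python_version():
        # highest supported minor: 3.10 on Windows ("nt"), 3.11 elsewhere
        low, high = (3, 8), ((3, 10) if os.name == "nt" else (3, 11))
        if not (low <= sys.version_info[:2] <= high):
            print("Unsupported Python version, please use Python 3.10.x instead.")
            sys.exit(1)

Comparing (major, minor) tuples directly avoids the off-by-one-prone
range(7 + 1, 10 + 1) arithmetic that PATCH 062 corrects.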
From e0df864b8c1f99d7d65d56c6ac8a1e5e314dddba Mon Sep 17 00:00:00 2001
From: brkirch
Date: Wed, 25 Jan 2023 13:19:06 -0500
Subject: [PATCH 063/151] Update arguments to use --upcast-sampling

---
 webui-macos-env.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 95ca9c55..fa187dd1 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -10,7 +10,7 @@ then
 fi

 export install_dir="$HOME"
-export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
+export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate"
 export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
 export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
 export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"

From d1d6ce29831d1b067801c3206f314258de88f683 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 23:25:25 +0300
Subject: [PATCH 064/151] add edit_image_conditioning from my earlier edits in
 case there's an attempt to integrate pix2pix properly

this allows using the pix2pix model in img2img, though it won't work well this way

---
 modules/processing.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/modules/processing.py b/modules/processing.py
index 9e5a2f38..cb41288a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -185,7 +185,12 @@ class StableDiffusionProcessing:
         conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
         return conditioning

-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+    def edit_image_conditioning(self, source_image):
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+
+        return conditioning_image
+
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
         self.is_using_inpainting_conditioning = True

         # Handle the different mask inputs
@@ -228,6 +233,9 @@ class StableDiffusionProcessing:
         if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
             return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)

+        if self.sd_model.cond_stage_key == "edit":
+            return self.edit_image_conditioning(source_image)
+
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
             return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)

From 6cff4401824299a983c8e13424018efc347b4a2b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 23:25:40 +0300
Subject: [PATCH 065/151] fix prompt editing break after first batch in img2img

---
 modules/sd_samplers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6261d1f7..a7910b56 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -454,7 +454,7 @@ class KDiffusionSampler:
     def initialize(self, p):
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
-        self.model_wrap.step = 0
+        self.model_wrap_cfg.step = 0
         self.eta = p.eta or opts.eta_ancestral

         k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

From d82d471bf7797fe09dbc6f3a6ac1c2a76142c8f1 Mon Sep 17 00:00:00 2001
From: Vladimir Repin
<32306715+mezotaken@users.noreply.github.com> Date: Thu, 26 Jan 2023 02:09:14 +0300 Subject: [PATCH 066/151] Ask user to clarify conditions --- .github/ISSUE_TEMPLATE/bug_report.yml | 29 +++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index ed372f22..7d435297 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -37,20 +37,20 @@ body: id: what-should attributes: label: What should have happened? - description: tell what you think the normal behavior should be + description: Tell what you think the normal behavior should be validations: required: true - type: input id: commit attributes: label: Commit where the problem happens - description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI) + description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) validations: required: true - type: dropdown id: platforms attributes: - label: What platforms do you use to access UI ? + label: What platforms do you use to access the UI ? multiple: true options: - Windows @@ -74,10 +74,27 @@ body: id: cmdargs attributes: label: Command Line Arguments - description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below + description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise. render: Shell + validations: + required: true + - type: textarea + id: extensions + attributes: + label: List of extensions + description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise. + validations: + required: true + - type: textarea + id: logs + attributes: + label: Console logs + description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. + render: Shell + validations: + required: true - type: textarea id: misc attributes: - label: Additional information, context and logs - description: Please provide us with any relevant additional info, context or log output. + label: Additional information + description: Please provide us with any relevant additional info or context. 
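Editor's note (not part of the patch series): the next patch swaps the infotext
parameter regex for one that keeps commas inside quoted values. A small sketch of
the intended parsing behaviour under the new pattern (the sample line is a made-up
illustration):

    import re

    re_param = re.compile(r'\s*([\w ]+):\s*("[^"]*"|[^,]+)')

    line = 'Steps: 20, Sampler: Euler a, Hashes: "model: abc, lora: def"'
    for key, value in re_param.findall(line):
        # quoted values keep their internal commas; strip the quotes afterwards
        if value.startswith('"') and value.endswith('"'):
            value = value[1:-1]
        print(key.strip(), '->', value)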
From e57b5f7c5560c49fbaf05e6bea326478222cb3e6 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 25 Jan 2023 22:36:14 -0500 Subject: [PATCH 067/151] re_param captures quotes with commas properly and removes unnecessary regex --- modules/generation_parameters_copypaste.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 46e12dc6..13d0874d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,9 +11,8 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*([\w ]+):\s*(\"[^\"]*\"|[^,]+)' re_param = re.compile(re_param_code) -re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) @@ -243,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model done_with_prompt = False *lines, lastline = x.strip().split("\n") - if not re_params.match(lastline): + if not re_param.match(lastline): lines.append(lastline) lastline = '' @@ -262,6 +261,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model res["Negative prompt"] = negative_prompt for k, v in re_param.findall(lastline): + v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v m = re_imagesize.match(v) if m is not None: res[k+"-1"] = m.group(1) From 4d634dc592ffdbd4ebb2f1acfb9a63f5e26e4deb Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 26 Jan 2023 00:18:41 -0500 Subject: [PATCH 068/151] adds components to infotext_fields allows for loading script params --- modules/scripts.py | 14 ++++++++++++++ scripts/xyz_grid.py | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..eefdfdd4 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -330,6 +330,20 @@ class ScriptRunner: outputs=[script.group for script in self.selectable_scripts] ) + self.script_load_ctr = 0 + def onload_script_visibility(params): + title = params.get('Script', None) + if title: + title_index = self.titles.index(title) + visibility = title_index == self.script_load_ctr + self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles) + return gr.update(visible=visibility) + else: + return gr.update(visible=False) + + self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) ) + self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] ) + return inputs def run(self, p: StableDiffusionProcessing, *args): diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 828c2d12..f3378686 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -383,6 +383,15 @@ class Script(scripts.Script): y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) + self.infotext_fields = ( + (x_type, "X Type"), + (x_values, "X Values"), + (y_type, "Y Type"), + (y_values, "Y Values"), + (z_type, "Z Type"), + (z_values, "Z Values"), + ) + return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, 
include_lone_images, include_sub_grids, no_fixed_seeds): @@ -541,6 +550,7 @@ class Script(scripts.Script): if grid_infotext[0] is None: pc.extra_generation_params = copy(pc.extra_generation_params) + pc.extra_generation_params['Script'] = self.title() if x_opt.label != 'Nothing': pc.extra_generation_params["X Type"] = x_opt.label From 10421f93c3f7f7ce88cb40391b46d4e6664eff74 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 00:34:38 -0500 Subject: [PATCH 069/151] Fix full previews, --no-half-vae --- modules/processing.py | 8 ++++---- modules/sd_hijack_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index cb41288a..92894d67 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), @@ -217,7 +217,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x) + x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) return x @@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. 
- image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) + image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f81b169a..f8684475 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -5,7 +5,7 @@ class CondFunc: self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') - for i in range(len(func_path)-2, -1, -1): + for i in range(len(func_path)-1, -1, -1): try: resolved_obj = importlib.import_module('.'.join(func_path[:i])) break From 1619233a747830887831cfea2f05fe826fce1bed Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Thu, 26 Jan 2023 12:52:44 +0500 Subject: [PATCH 070/151] Only Linux will have max 3.11 --- launch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/launch.py b/launch.py index e39c68e7..52f3bd52 100644 --- a/launch.py +++ b/launch.py @@ -20,10 +20,10 @@ skip_install = False def check_python_version(): version = sys.version_info version_range = None - if os.name == "nt": - version_range = range(7 + 1, 10 + 1) - else: + if platform.system() == "Linux": version_range = range(7 + 1, 11 + 1) + else: + version_range = range(7 + 1, 10 + 1) try: assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too." From f4ec411f2c9d6bc6817a2eca8a2c00f255ffb386 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 03:45:16 -0500 Subject: [PATCH 071/151] Allow checkpoint merger to merge pix2pix models in the same way that it currently supports inpainting models. --- modules/extras.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 36123aa5..67ffdee3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False + result_is_pix2pix_model = False if theta_func2: shared.state.textinfo = f"Loading B" @@ -186,13 +187,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True + if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... + print("Detected possible merge of instruct model with non-instruct model.") + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. 
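+                        # Editor's note (added comment, not in the original patch): an
+                        # instruct-pix2pix UNet's first conv takes 8 input channels (4 for the
+                        # noised latent, 4 for the encoded conditioning image), so only the
+                        # first 4 channels have counterparts in a standard 4-channel model.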
+ result_is_pix2pix_model = True + else: + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True else: theta_0[key] = theta_func2(a, b, multiplier) - + theta_0[key] = to_half(theta_0[key], save_as_half) shared.state.sampling_step += 1 @@ -226,6 +231,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" + filename += ".pix2pix" if result_is_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) From f90798c6b6cc48e514acb08ce02bdb5874bf74d8 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 04:38:04 -0500 Subject: [PATCH 072/151] Added error check for the rare case a user merges a pix2pix model with a normal model using weighted sum. Also removed bad print message that interfered with merging progress bar. --- modules/extras.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/extras.py b/modules/extras.py index 67ffdee3..badd13c7 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -186,9 +186,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + if a.shape[1] == 4 and b.shape[1] == 8: + raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.") if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... - print("Detected possible merge of instruct model with non-instruct model.") theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. result_is_pix2pix_model = True else: From 9e72dc743480c8b1ca6aeb8ced3af03f3e3243a3 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 06:05:40 -0500 Subject: [PATCH 073/151] Changed all references to "pix2pix" to the more precise name "instruct pix2pix". Also changed extension to instrpix2pix at least for now. 
--- modules/extras.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index badd13c7..2bf0d17e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -132,7 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False - result_is_pix2pix_model = False + result_is_instruct_pix2pix_model = False if theta_func2: shared.state.textinfo = f"Loading B" @@ -187,11 +187,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") if a.shape[1] == 4 and b.shape[1] == 8: - raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.") + raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.") - if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... + if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model... theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. - result_is_pix2pix_model = True + result_is_instruct_pix2pix_model = True else: assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) @@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" - filename += ".pix2pix" if result_is_pix2pix_model else "" + filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) From c4b9b07db6272768428fa8efeb7d7a9f22eca0b1 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 09:00:15 -0500 Subject: [PATCH 074/151] Fix embeddings dtype mismatch --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f9652d21..531790f3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = embedding.vec + emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) From cdc2fa209a3efdc71a90643a5e7a1df49869cd5f Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 11:27:07 -0500 Subject: [PATCH 075/151] Changed filename addition from "instrpix2pix" to the more readable ".instruct-pix2pix" for newly generated instruct pix2pix models. 
---
 modules/extras.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/extras.py b/modules/extras.py
index 2bf0d17e..466ecc15 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_

     filename = filename_generator() if custom_name == '' else custom_name
     filename += ".inpainting" if result_is_inpainting_model else ""
-    filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else ""
+    filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
     filename += "." + checkpoint_format

     output_modelname = os.path.join(ckpt_dir, filename)

From 7a14c8ab45da8a681792a6331d48a88dd684a0a9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 26 Jan 2023 23:29:27 +0300
Subject: [PATCH 076/151] add an option to enable sections from extras tab in
 txt2img/img2img

fix some style inconsistencies

---
 modules/processing.py                  |  7 ++++-
 modules/scripts.py                     | 32 +++++++++++++++++---
 modules/scripts_auto_postprocessing.py | 42 ++++++++++++++++++++++++++
 modules/scripts_postprocessing.py      | 11 +++++--
 modules/shared.py                      | 15 +++------
 modules/shared_items.py                | 10 ++++++
 modules/ui_components.py               |  8 +++++
 scripts/postprocessing_upscale.py      | 25 +++++++++++++++
 style.css                              |  6 +---
 9 files changed, 133 insertions(+), 23 deletions(-)
 create mode 100644 modules/scripts_auto_postprocessing.py
 create mode 100644 modules/shared_items.py

diff --git a/modules/processing.py b/modules/processing.py
index 92894d67..262806a1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
 from typing import Any, Dict, List, Optional

 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

                 image = Image.fromarray(x_sample)

+                if p.scripts is not None:
+                    pp = scripts.PostprocessImageArgs(image)
+                    p.scripts.postprocess_image(p, pp)
+                    image = pp.image
+
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                         image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
diff --git a/modules/scripts.py b/modules/scripts.py
index 03907a63..6e9dc0c0 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -6,12 +6,16 @@ from collections import namedtuple

 import gradio as gr

-from modules.processing import StableDiffusionProcessing
 from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing

 AlwaysVisible = object()

+class PostprocessImageArgs:
+    def __init__(self, image):
+        self.image = image
+
+
 class Script:
     filename = None
     args_from = None
@@ -65,7 +69,7 @@ class Script:
         args contains all values returned by components from ui()
         """

-        raise NotImplementedError()
+        pass

     def process(self, p, *args):
         """
@@ -100,6 +104,13 @@ class Script:

         pass

+    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
+        """
+        Called for every image after it has been generated.
+ """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -247,11 +258,15 @@ class ScriptRunner: self.infotext_fields = [] def initialize_scripts(self, is_img2img): + from modules import scripts_auto_postprocessing + self.scripts.clear() self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir, script_module in scripts_data: + auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() + + for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img @@ -332,7 +347,7 @@ class ScriptRunner: return inputs - def run(self, p: StableDiffusionProcessing, *args): + def run(self, p, *args): script_index = args[0] if script_index == 0: @@ -386,6 +401,15 @@ class ScriptRunner: print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def postprocess_image(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image(p, pp, *script_args) + except Exception: + print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): for script in self.scripts: try: diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py new file mode 100644 index 00000000..30d6d658 --- /dev/null +++ b/modules/scripts_auto_postprocessing.py @@ -0,0 +1,42 @@ +from modules import scripts, scripts_postprocessing, shared + + +class ScriptPostprocessingForMainUI(scripts.Script): + def __init__(self, script_postproc): + self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc + self.postprocessing_controls = None + + def title(self): + return self.script.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.postprocessing_controls = self.script.ui() + return self.postprocessing_controls.values() + + def postprocess_image(self, p, script_pp, *args): + args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)} + + pp = scripts_postprocessing.PostprocessedImage(script_pp.image) + pp.info = {} + self.script.process(pp, **args_dict) + p.extra_generation_params.update(pp.info) + script_pp.image = pp.image + + +def create_auto_preprocessing_script_data(): + from modules import scripts + + res = [] + + for name in shared.opts.postprocessing_enable_in_main_ui: + script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) + if script is None: + continue + + constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) + res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module)) + + return res diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 25de02d0..ce0ebb61 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -46,6 +46,8 @@ class ScriptPostprocessing: pass + + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: res = func(*args, **kwargs) @@ -68,6 +70,9 @@ class ScriptPostprocessingRunner: script: ScriptPostprocessing = script_class() script.filename 

+            if script.name == "Simple Upscale":
+                continue
+
             self.scripts.append(script)

     def create_script_ui(self, script, inputs):
@@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
             import modules.scripts
             self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

-        scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+        scripts_order = shared.opts.postprocessing_operation_order

         def script_score(name):
-            name = name.lower()
             for i, possible_match in enumerate(scripts_order):
-                if possible_match in name:
+                if possible_match == name:
                     return i

             return len(self.scripts)
@@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
     def image_changed(self):
         for script in self.scripts_in_preferred_order():
             script.image_changed()
+

diff --git a/modules/shared.py b/modules/shared.py
index 6a0b96cb..cdeed55d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,8 +13,8 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
-from modules.paths import models_path, script_path, sd_path
+from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
+from modules.paths import models_path, script_path

 demo = None
@@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")

 face_restorers = []

-
-def realesrgan_models_names():
-    import modules.realesrgan_model
-    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
-
-
 class OptionInfo:
     def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
         self.default = default
@@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
 options_templates.update(options_section(('upscaling', "Upscaling"), {
     "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
     "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
-    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
+    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
     "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
 }))
@@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
 }))

 options_templates.update(options_section(('postprocessing', "Postprocessing"), {
-    'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
+    'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+    'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
     'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
 }))

diff --git a/modules/shared_items.py b/modules/shared_items.py
new file mode 100644
index 00000000..b5d480c9
--- /dev/null
+++ b/modules/shared_items.py
@@ -0,0 +1,10 @@
+
+
+def realesrgan_models_names():
+    import modules.realesrgan_model
+    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
+
+def postprocessing_scripts():
+    import modules.scripts
+
+    return modules.scripts.scripts_postproc.scripts
\ No newline at end of file

diff --git a/modules/ui_components.py b/modules/ui_components.py
index 9aec3097..284ca0cf 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):

     def get_block_name(self):
         return "colorpicker"
+
+
+class DropdownMulti(gr.Dropdown):
+    """Same as gr.Dropdown but always multiselect"""
+    def __init__(self, **kwargs):
+        super().__init__(multiselect=True, **kwargs)
+
+    def get_block_name(self):
+        return "dropdown"

diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index 095d29b2..8842bd91 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):

     def image_changed(self):
         upscale_cache.clear()
+
+
+class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
+    name = "Simple Upscale"
+    order = 900
+
+    def ui(self):
+        with FormRow():
+            upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+            upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
+
+        return {
+            "upscale_by": upscale_by,
+            "upscaler_name": upscaler_name,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+        if upscaler_name is None or upscaler_name == "None":
+            return
+
+        upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
+        assert upscaler1, f'could not find upscaler named {upscaler_name}'
+
+        pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
+        pp.info[f"Postprocess upscaler"] = upscaler1.name

diff --git a/style.css b/style.css
index ec046f78..dd914104 100644
--- a/style.css
+++ b/style.css
@@ -164,7 +164,7 @@
     min-height: 3.2em;
 }

-#txt2img_styles ul, #img2img_styles ul{
+ul.list-none{
     max-height: 35em;
     z-index: 2000;
 }
@@ -714,9 +714,6 @@ footer {
     white-space: nowrap;
     min-width: auto;
 }
-#txt2img_hires_fix{
-    margin-left: -0.8em;
-}

 #img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
     margin-left: 0em;
@@ -744,7 +741,6 @@ footer {

 .dark .gr-compact{
     background-color: rgb(31 41 55 / var(--tw-bg-opacity));
-    margin-left: 0.8em;
 }

 .gr-compact{

From a43fafb481feb3ef369d1963412de4e7b320fc34 Mon Sep 17 00:00:00 2001
From: ItsOlegDm
Date: Thu, 26 Jan 2023 23:25:48 +0200
Subject: [PATCH 077/151] css fixes

---
 style.css | 45 ++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 38 insertions(+), 7 deletions(-)

diff --git a/style.css b/style.css
index dd914104..1e90b105 100644
--- a/style.css
+++ b/style.css
@@ -74,7 +74,12 @@
 #txt2img_gallery img, #img2img_gallery img{
     object-fit: scale-down;
 }
-
+#txt2img_actions_column, #img2img_actions_column {
+    margin: 0.35rem 0.75rem 0.35rem 0;
+}
+#script_list {
+    padding: .625rem .75rem 0 .625rem;
+}
 .justify-center.overflow-x-scroll {
     justify-content: left;
 }
@@ -126,10 +131,12 @@

 #txt2img_actions_column, #img2img_actions_column{
     gap: 0;
+    margin-right: .75rem;
 }

 #txt2img_tools, #img2img_tools{
     gap: 0.4em;
+    justify-content: center;
 }

 #interrogate_col{
@@ -155,7 +162,9 @@
 #txt2img_styles_row > button, #img2img_styles_row > button{
     margin: 0;
 }
-
+#txt2img_styles_row {
+    margin-top: 0.3em;
+}
 #txt2img_styles, #img2img_styles{
     padding: 0;
 }
@@ -311,11 +320,11 @@ input[type="range"]{
 .min-h-\[6rem\] { min-height: unset !important; }

 .progressDiv{
-    position: absolute;
+    position: relative;
     height: 20px;
-    top: -20px;
     background: #b4c0cc;
     border-radius: 3px !important;
+    margin-bottom: -3px;
 }

 .dark .progressDiv{
@@ -535,7 +544,7 @@ input[type="range"]{
 }

 #quicksettings {
-    gap: 0.4em;
+    width: fit-content;
 }

 #quicksettings > div, #quicksettings > fieldset{
@@ -545,6 +554,7 @@ input[type="range"]{
     border: none;
     box-shadow: none;
     background: none;
+    margin-right: 10px;
 }

 #quicksettings > div > div > div > label > span {
@@ -567,7 +577,7 @@ canvas[key="mask"] {
     right: 0.5em;
     top: -0.6em;
     z-index: 400;
-    width: 8em;
+    width: 6em;
 }
 #quicksettings .gr-box > div > div > input.gr-text-input {
     top: -1.12em;
@@ -665,11 +675,27 @@ canvas[key="mask"] {

 #quicksettings .gr-button-tool{
     margin: 0;
+    border-color: unset;
+    background-color: unset;
 }
-
+#modelmerger_interp_description>p {
+    margin: 0!important;
+    text-align: center;
+}
+#modelmerger_interp_description {
+    margin: 0.35rem 0.75rem 1.23rem;
+}
 #img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
     padding-top: 0.9em;
+    padding-bottom: 0.9em;
+}
+#txt2img_settings {
+    padding-top: 1.16em;
+    padding-bottom: 0.9em;
+}
+#img2img_settings {
+    padding-bottom: 0.9em;
 }

 #img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
@@ -741,6 +767,8 @@ footer {

 .dark .gr-compact{
     background-color: rgb(31 41 55 / var(--tw-bg-opacity));
+    align-items: center;
+    margin-left: 0;
 }

 .gr-compact{
@@ -925,3 +953,6 @@ footer {
     color: red;
 }

+[id*='_prompt_container'] > div {
+    margin: 0!important;
+}

From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 11:28:12 +0300
Subject: [PATCH 078/151] remove the need to place configs near models

---
 configs/instruct-pix2pix.yaml    |  99 ++++++++
 .../v1-inpainting-inference.yaml |  36 +--
 modules/api/api.py               |   5 +-
 modules/devices.py               |  12 +-
 modules/sd_hijack_inpainting.py  |   9 -
 modules/sd_models.py             | 230 +++++++++---------
 modules/sd_models_config.py      |  65 +++++
 modules/shared.py                |   7 +-
 modules/shared_items.py          |  15 +-
 modules/timer.py                 |  35 +++
 10 files changed, 361 insertions(+), 152 deletions(-)
 create mode 100644 configs/instruct-pix2pix.yaml
 rename v2-inference-v.yaml => configs/v1-inpainting-inference.yaml (61%)
 create mode 100644 modules/sd_models_config.py
 create mode 100644 modules/timer.py

diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml
new file mode 100644
index 00000000..437ddcef
--- /dev/null
+++ b/configs/instruct-pix2pix.yaml
@@ -0,0 +1,99 @@
+# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+# See more details in LICENSE.
+
+model:
+  base_learning_rate: 1.0e-04
+  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: edited
+    cond_stage_key: edit
+    # image_size: 64
+    # image_size: 32
+    image_size: 16
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: hybrid
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: true
+    load_ema: true
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 0 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 128
+    num_workers: 1
+    wrap: false
+    validation:
+      target: edit_dataset.EditDataset
+      params:
+        path: data/clip-filtered-dataset
+        cache_dir: data/
+        cache_name: data_10k
+        split: val
+        min_text_sim: 0.2
+        min_image_sim: 0.75
+        min_direction_sim: 0.2
+        max_samples_per_prompt: 1
+        min_resize_res: 512
+        max_resize_res: 512
+        crop_res: 512
+        output_as_edit: False
+        real_input: True

diff --git a/v2-inference-v.yaml b/configs/v1-inpainting-inference.yaml
similarity index 61%
rename from v2-inference-v.yaml
rename to configs/v1-inpainting-inference.yaml
index 513cd635..f9eec37d 100644
--- a/v2-inference-v.yaml
+++ b/configs/v1-inpainting-inference.yaml
@@ -1,8 +1,7 @@
 model:
-  base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
   params:
-    parameterization: "v"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -12,29 +11,36 @@ model:
     cond_stage_key: "txt"
     image_size: 64
     channels: 4
-    cond_stage_trainable: false
-    conditioning_key: crossattn
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: hybrid # important
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: False # we set this to false because this is an inference only config
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]

     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
       params:
-        use_checkpoint: True
-        use_fp16: True
         image_size: 32 # unused
-        in_channels: 4
+        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
         out_channels: 4
         model_channels: 320
         attention_resolutions: [ 4, 2, 1 ]
         num_res_blocks: 2
         channel_mult: [ 1, 2, 4, 4 ]
-        num_head_channels: 64 # need to fix for flash-attn
+        num_heads: 8
         use_spatial_transformer: True
-        use_linear_in_transformer: True
         transformer_depth: 1
-        context_dim: 1024
+        context_dim: 768
+        use_checkpoint: True
         legacy: False

     first_stage_config:
@@ -43,7 +49,6 @@ model:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
-          #attn_type: "vanilla-xformers"
           double_z: true
           z_channels: 4
           resolution: 256
@@ -62,7 +67,4 @@ model:
         target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
-        freeze: True
-        layer: "penultimate"
\ No newline at end of file
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

diff --git a/modules/api/api.py b/modules/api/api.py
index 25c65e57..eb7b1da5 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
@@ -387,7 +388,7 @@ class Api:
         ]

     def get_sd_models(self):
-        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

     def get_hypernetworks(self):
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

diff --git a/modules/devices.py b/modules/devices.py
index 6b36622c..2d5f797a 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,14 +34,18 @@ def get_cuda_device_string():

     return "cuda"

-def get_optimal_device():
+def get_optimal_device_name():
     if torch.cuda.is_available():
-        return torch.device(get_cuda_device_string())
+        return get_cuda_device_string()

     if has_mps():
-        return torch.device("mps")
+        return "mps"

-    return cpu
+    return "cpu"
+
+
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())

 def get_device_for(task):

diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 31d2c898..478cd499 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
     return x_prev, pred_x0, e_t

-def should_hijack_inpainting(checkpoint_info):
-    from modules import sd_models
-
-    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
-    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
-
-    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
-
-
 def do_inpainting_hijack():
     # p_sample_plms is needed because PLMS can't work with dicts as conditionings

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7072eb2e..fa208728 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -2,8 +2,6 @@ import collections
 import os.path
 import sys
 import gc
-import time
-from collections import namedtuple
 import torch
 import re
 import safetensors.torch
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas

 from ldm.util import instantiate_from_config

-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer

 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -99,17 +97,6 @@ def checkpoint_tiles():
     return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)

-def find_checkpoint_config(info):
-    if info is None:
-        return shared.cmd_opts.config
-
-    config = os.path.splitext(info.filename)[0] + ".yaml"
-    if os.path.exists(config):
-        return config
-
-    return shared.cmd_opts.config
-
-
 def list_models():
     checkpoints_list.clear()
     checkpoint_alisases.clear()
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
-        device = map_location or shared.weight_load_location
-        if device is None:
-            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
         pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
     return sd

-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+    sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
+    if checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        return checkpoints_loaded[checkpoint_info]
+
+    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    res = read_state_dict(checkpoint_info.filename)
+    timer.record("load weights from disk")
+
+    return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     title = checkpoint_info.title
     sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
     if checkpoint_info.title != title:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

-    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+    if state_dict is None:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

-    if cache_enabled and checkpoint_info in checkpoints_loaded:
-        # use checkpoint cache
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-    else:
-        # load from file
-        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    model.load_state_dict(state_dict, strict=False)
+    del state_dict
+    timer.record("apply weights to model")

-        sd = read_state_dict(checkpoint_info.filename)
-        model.load_state_dict(sd, strict=False)
-        del sd
-
-        if cache_enabled:
-            # cache newly loaded model
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+    if shared.opts.sd_checkpoint_cache > 0:
+        # cache newly loaded model
+        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

-        if shared.cmd_opts.opt_channelslast:
-            model.to(memory_format=torch.channels_last)
+    if shared.cmd_opts.opt_channelslast:
+        model.to(memory_format=torch.channels_last)
+        timer.record("apply channels_last")

-        if not shared.cmd_opts.no_half:
-            vae = model.first_stage_model
-            depth_model = getattr(model, 'depth_model', None)
+    if not shared.cmd_opts.no_half:
+        vae = model.first_stage_model
+        depth_model = getattr(model, 'depth_model', None)

-            # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-            if shared.cmd_opts.no_half_vae:
-                model.first_stage_model = None
-            # with --upcast-sampling, don't convert the depth model weights to float16
-            if shared.cmd_opts.upcast_sampling and depth_model:
-                model.depth_model = None
+        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+        if shared.cmd_opts.no_half_vae:
+            model.first_stage_model = None
+        # with --upcast-sampling, don't convert the depth model weights to float16
+        if shared.cmd_opts.upcast_sampling and depth_model:
+            model.depth_model = None

-            model.half()
-            model.first_stage_model = vae
-            if depth_model:
-                model.depth_model = depth_model
+        model.half()
+        model.first_stage_model = vae
+        if depth_model:
+            model.depth_model = depth_model

-        devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
-        devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
-        devices.dtype_unet = model.model.diffusion_model.dtype
-        devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+        timer.record("apply half()")

-        model.first_stage_model.to(devices.dtype_vae)
+    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+    devices.dtype_unet = model.model.diffusion_model.dtype
+    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+
+    model.first_stage_model.to(devices.dtype_vae)
+    timer.record("apply dtype to VAE")

     # clean up cache if limit is reached
-    if cache_enabled:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
-            checkpoints_loaded.popitem(last=False)  # LRU
+    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+        checkpoints_loaded.popitem(last=False)

     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_info.filename
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
     sd_vae.clear_loaded_vae()
     vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
     sd_vae.load_vae(model, vae_file, vae_source)
+    timer.record("load VAE")

 def enable_midas_autodownload():
@@ -340,24 +340,20 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper

-class Timer:
-    def __init__(self):
-        self.start = time.time()
+def repair_config(sd_config):

-    def elapsed(self):
-        end = time.time()
-        res = end - self.start
-        self.start = end
-        return res
+    if not hasattr(sd_config.model.params, "use_ema"):
+        sd_config.model.params.use_ema = False
+
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = True

-def load_model(checkpoint_info=None):
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
-    checkpoint_config = find_checkpoint_config(checkpoint_info)
-
-    if checkpoint_config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_config}")

     if shared.sd_model:
         sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -365,38 +361,27 @@
         gc.collect()
         devices.torch_gc()

-    sd_config = OmegaConf.load(checkpoint_config)
-
-    if should_hijack_inpainting(checkpoint_info):
-        # Hardcoded config for now...
-        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.unet_config.params.in_channels = 9
-        sd_config.model.params.finetune_keys = None
-
-    if should_hijack_ip2p(checkpoint_info):
-        sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.first_stage_key = "edited"
-        sd_config.model.params.cond_stage_key = "edit"
-        sd_config.model.params.image_size = 16
-        sd_config.model.params.unet_config.params.in_channels = 8
-        sd_config.model.params.unet_config.params.out_channels = 4
-
-    if not hasattr(sd_config.model.params, "use_ema"):
-        sd_config.model.params.use_ema = False
-
     do_inpainting_hijack()

-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
-
     timer = Timer()

-    sd_model = None
+    if already_loaded_state_dict is not None:
+        state_dict = already_loaded_state_dict
+    else:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    sd_config = OmegaConf.load(checkpoint_config)
+    repair_config(sd_config)
+
+    timer.record("load config")
+
+    print(f"Creating model from config: {checkpoint_config}")
+
+    sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
@@ -407,29 +392,35 @@
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         sd_model = instantiate_from_config(sd_config.model)

-    elapsed_create = timer.elapsed()
+    sd_model.used_config = checkpoint_config

-    load_model_weights(sd_model, checkpoint_info)
+    timer.record("create model")

-    elapsed_load_weights = timer.elapsed()
+    load_model_weights(sd_model, checkpoint_info, state_dict, timer)

     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
         sd_model.to(shared.device)

+    timer.record("move model to device")
+
     sd_hijack.model_hijack.hijack(sd_model)

+    timer.record("hijack")
+
     sd_model.eval()
     shared.sd_model = sd_model

     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
+    timer.record("load textual inversion embeddings")
+
     script_callbacks.model_loaded_callback(sd_model)

-    elapsed_the_rest = timer.elapsed()
+    timer.record("scripts callbacks")

-    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
+    print(f"Model loaded in {timer.summary()}.")

     return sd_model
@@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):

     if not sd_model:
         sd_model = shared.sd_model
+
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return

-        checkpoint_config = find_checkpoint_config(current_checkpoint_info)
-
-        if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
-            del sd_model
-            checkpoints_loaded.clear()
-            load_model(checkpoint_info)
-            return shared.sd_model
-
         if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
             lowvram.send_everything_to_cpu()
         else:
             sd_model.to(devices.cpu)

         sd_hijack.model_hijack.undo_hijack(sd_model)

     timer = Timer()

+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    if sd_model is None or checkpoint_config != sd_model.used_config:
+        del sd_model
+        checkpoints_loaded.clear()
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+        return shared.sd_model
+
     try:
-        load_model_weights(sd_model, checkpoint_info)
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     except Exception as e:
         print("Failed to load checkpoint, restoring previous")
-        load_model_weights(sd_model, current_checkpoint_info)
+        load_model_weights(sd_model, current_checkpoint_info, None, timer)
         raise
     finally:
         sd_hijack.model_hijack.hijack(sd_model)
+        timer.record("hijack")
+
         script_callbacks.model_loaded_callback(sd_model)
+        timer.record("script callbacks")

         if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
             sd_model.to(devices.device)
+            timer.record("move model to device")

-    elapsed = timer.elapsed()
-
-    print(f"Weights loaded in {elapsed:.1f}s.")
+    print(f"Weights loaded in {timer.summary()}.")

     return sd_model

diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
new file mode 100644
index 00000000..ea773a10
--- /dev/null
+++ b/modules/sd_models_config.py
@@ -0,0 +1,65 @@
+import re
+import os
+
+from modules import shared, paths
+
+sd_configs_path = shared.sd_configs_path
+sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+
+
+config_default = shared.sd_default_config
+config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
+config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
+config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
+config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
+
+re_parametrization_v = re.compile(r'-v\b')
+
+
+def guess_model_config_from_state_dict(sd, filename):
+    fn = os.path.basename(filename)
+
+    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
+    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
+    roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)
+
+    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
+        if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
+            return config_sd2v
+        else:
+            return config_sd2
+
+    if diffusion_model_input is not None:
+        if diffusion_model_input.shape[1] == 9:
+            return config_inpainting
+        if diffusion_model_input.shape[1] == 8:
+            return config_instruct_pix2pix
+
+    if roberta_weight is not None:
+        return config_alt_diffusion
+
+    return config_default
+
+
+def find_checkpoint_config(state_dict, info):
+    if info is None:
+        return guess_model_config_from_state_dict(state_dict, "")
+
+    config = find_checkpoint_config_near_filename(info)
+    if config is not None:
+        return config
+
+    return guess_model_config_from_state_dict(state_dict, info.filename)
+
+
+def find_checkpoint_config_near_filename(info):
+    if info is None:
+        return None
+
+    config = os.path.splitext(info.filename)[0] + ".yaml"
+    if os.path.exists(config):
+        return config
+
+    return None
+

diff --git a/modules/shared.py b/modules/shared.py
index cdeed55d..14be993d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,13 +13,14 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
+from modules import localization, extensions, script_loading, errors, ui_components, shared_items
 from modules.paths import models_path, script_path

 demo = None

-sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")

 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
@@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),

diff --git a/modules/shared_items.py b/modules/shared_items.py
index b5d480c9..8b5ec96d 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -4,7 +4,20 @@ def realesrgan_models_names():
     import modules.realesrgan_model
     return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]

+
 def postprocessing_scripts():
     import modules.scripts

-    return modules.scripts.scripts_postproc.scripts
\ No newline at end of file
+    return modules.scripts.scripts_postproc.scripts
+
+
+def sd_vae_items():
+    import modules.sd_vae
+
+    return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
+
+
+def refresh_vae_list():
+    import modules.sd_vae
+
+    return modules.sd_vae.refresh_vae_list

diff --git a/modules/timer.py b/modules/timer.py
new file mode 100644
index 00000000..57a4f17a
--- /dev/null
+++ b/modules/timer.py
@@ -0,0 +1,35 @@
+import time
+
+
+class Timer:
+    def __init__(self):
+        self.start = time.time()
+        self.records = {}
+        self.total = 0
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+    def record(self, category, extra_time=0):
+        e = self.elapsed()
+        if category not in self.records:
+            self.records[category] = 0
+
+        self.records[category] += e + extra_time
+        self.total += e + extra_time
+
+    def summary(self):
+        res = f"{self.total:.1f}s"
+
+        additions = [x for x in self.records.items() if x[1] >= 0.1]
+        if not additions:
+            return res
+
+        res += " ("
+        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
+        res += ")"
+
+        return res

From 6f31d2210c189f8db118e6f95add7ba2a64f0238 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 11:54:19 +0300
Subject: [PATCH 079/151] support detecting midas model fix broken api for
 checkpoint list

---
 modules/api/models.py       |  2 +-
 modules/sd_models.py        | 10 +++++-----
 modules/sd_models_config.py |  7 +++++--
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/modules/api/models.py b/modules/api/models.py
index 805bd8f7..cba43d3b 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -228,7 +228,7 @@ class SDModelItem(BaseModel):
     hash: Optional[str] = Field(title="Short hash")
     sha256: Optional[str] = Field(title="sha256 hash")
     filename: str = Field(title="Filename")
-    config: str = Field(title="Config file")
+    config: Optional[str] = Field(title="Config file")

 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")

diff --git a/modules/sd_models.py b/modules/sd_models.py
index fa208728..37dad18d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -439,12 +439,12 @@ def reload_model_weights(sd_model=None, info=None):
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return

-        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-            lowvram.send_everything_to_cpu()
-        else:
-            sd_model.to(devices.cpu)
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        sd_model.to(devices.cpu)

-        sd_hijack.model_hijack.undo_hijack(sd_model)
+    sd_hijack.model_hijack.undo_hijack(sd_model)

     timer = Timer()

diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index ea773a10..4d1e92e1 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs",
 config_default = shared.sd_default_config
 config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
@@ -22,7 +23,9 @@ def guess_model_config_from_state_dict(sd, filename):

     sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
-    roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)
+
+    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+        return config_depth_model

     if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
         if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
@@ -36,7 +39,7 @@ def guess_model_config_from_state_dict(sd, filename):
         if diffusion_model_input.shape[1] == 8:
             return config_instruct_pix2pix

-    if roberta_weight is not None:
+    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
         return config_alt_diffusion

     return config_default

From 9beb794e0b0dc1a0f9e89d8e38bd789a8c608397 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 13:08:00 +0300
Subject: [PATCH 080/151] clarify the option to disable NaN check.

---
 modules/devices.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/modules/devices.py b/modules/devices.py
index 2d5f797a..4687944e 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -143,6 +143,8 @@ def test_for_nans(x, where):
     else:
         message = "A tensor with all NaNs was produced."

+    message += " Use --disable-nan-check commandline argument to disable this check."
+
     raise NansException(message)

From 9ecf1e827c5966e11495a0c066a127defbba9bcc Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Fri, 27 Jan 2023 17:35:24 +0500
Subject: [PATCH 081/151] Made it only a warning

---
 .gitignore |  1 +
 launch.py  | 39 ++++++++++++++++++++++++++++-----------
 2 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0b1d17ca..c8be9688 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,3 +33,4 @@ notification.mp3
 /test/stdout.txt
 /test/stderr.txt
 /cache.json
+no_py_ver_warning

diff --git a/launch.py b/launch.py
index 52f3bd52..4f5a4bc4 100644
--- a/launch.py
+++ b/launch.py
@@ -18,18 +18,35 @@ skip_install = False

 def check_python_version():
-    version = sys.version_info
-    version_range = None
-    if platform.system() == "Linux":
-        version_range = range(7 + 1, 11 + 1)
-    else:
-        version_range = range(7 + 1, 10 + 1)
+    if not os.path.isfile("no_py_ver_warning"):
+        version = sys.version_info
+        version_range = None
+        if platform.system() == "Linux":
+            version_range = range(7 + 1, 11 + 1)
+        else:
+            version_range = range(7 + 1, 10 + 1)

-    try:
-        assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too."
-    except AssertionError as e:
-        print(e)
-        sys.exit(-1)
+        try:
+            assert version.major == 3 and version.minor in version_range, f"""
+=== Warning ===
+This program was tested only with Python 3.10, but you have Python {version.major}.{version.minor}.
+If you encounter an error with a "RuntimeError: Couldn't install torch." message,
+or any other error caused by an unsuccessful package (library) installation,
+please downgrade (or upgrade) to the latest version of Python 3.10
+and delete your current Python and the "venv" folder in the WebUI's directory.
+
+You can download Python 3.10 from here: https://www.python.org/downloads/release/python-3109/
+
+You will see this warning only once; delete the "no_py_ver_warning" file to show this warning again.
+=== Warning ===
+
+Press ENTER to continue...\
+"""
+        except AssertionError as e:
+            print(e)
+            with open("no_py_ver_warning", "w"):
+                pass
+            input()

 def commit_hash():

From 5eee2ac39863f9e44591b50d0710dd2615416a13 Mon Sep 17 00:00:00 2001
From: Max Audron
Date: Wed, 25 Jan 2023 17:15:42 +0100
Subject: [PATCH 082/151] add data-dir flag and set all user data directories
 based on it

---
 modules/extensions.py                      |  2 +-
 modules/generation_parameters_copypaste.py |  4 ++--
 modules/gfpgan_model.py                    |  5 ++---
 modules/hashes.py                          |  4 +++-
 modules/interrogate.py                     |  2 +-
 modules/paths.py                           | 10 +++++++++-
 modules/processing.py                      |  3 ++-
 modules/sd_models.py                       |  6 +++---
 modules/sd_vae.py                          |  5 ++---
 modules/shared.py                          | 11 ++++++-----
 modules/textual_inversion/preprocess.py    |  5 ++---
 modules/ui.py                              |  6 +++---
 modules/ui_extensions.py                   |  2 +-
 modules/upscaler.py                        |  5 ++---
 14 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/modules/extensions.py b/modules/extensions.py
index b522125c..92ee8144 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -7,7 +7,7 @@ import git
 from modules import paths, shared

 extensions = []
-extensions_dir = os.path.join(paths.script_path, "extensions")
+extensions_dir = os.path.join(paths.data_path, "extensions")
 extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 46e12dc6..35f72808 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -6,7 +6,7 @@ import re
 from pathlib import Path

 import gradio as gr
-from modules.shared import script_path
+from modules.paths import data_path, script_path
 from modules import shared, ui_tempdir, script_callbacks
 import tempfile
 from PIL import Image
@@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 def connect_paste(button, paste_fields, input_comp, jsfunc=None):
     def paste_func(prompt):
         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
-            filename = os.path.join(script_path, "params.txt")
+            filename = os.path.join(data_path, "params.txt")
             if os.path.exists(filename):
                 with open(filename, "r", encoding="utf8") as file:
                     prompt = file.read()

diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 1e2dbc32..fbe6215a 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -6,12 +6,11 @@ import facexlib
 import gfpgan

 import modules.face_restoration
-from modules import shared, devices, modelloader
-from modules.paths import models_path
+from modules import paths, shared, devices, modelloader

 model_dir = "GFPGAN"
 user_path = None
-model_path = os.path.join(models_path, model_dir)
+model_path = os.path.join(paths.models_path, model_dir)
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 have_gfpgan = False
 loaded_gfpgan_model = None

diff --git a/modules/hashes.py b/modules/hashes.py
index b85a7580..819362a3 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -4,8 +4,10 @@ import os.path

 import filelock

+from modules.paths import data_path

-cache_filename = "cache.json"
+
+cache_filename = os.path.join(data_path, "cache.json")
 cache_data = None

diff --git a/modules/interrogate.py b/modules/interrogate.py
index c72ff694..cbb80683 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -12,7 +12,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

 import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors

 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'

diff --git a/modules/paths.py b/modules/paths.py
index 20b3e4d8..08e6f9b9 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -4,7 +4,15 @@ import sys
 import modules.safe

 script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-models_path = os.path.join(script_path, "models")
+
+# Parse the --data-dir flag first so we can use it as a base for our other argument default values
+parser = argparse.ArgumentParser()
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+cmd_opts_pre = parser.parse_known_args()[0]
+data_path = cmd_opts_pre.data_dir
+models_path = os.path.join(data_path, "models")
+
+# data_path = cmd_opts_pre.data
 sys.path.insert(0, script_path)

 # search for directory of stable diffusion in following places

diff --git a/modules/processing.py b/modules/processing.py
index 262806a1..5072fc40 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
+import modules.paths as paths
 import modules.face_restoration
 import modules.images as images
 import modules.styles
@@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if not p.disable_extra_networks:
             extra_networks.activate(p, extra_network_data)

-        with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+        with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
             processed = Processed(p, [], p.seed, "")
             file.write(processed.infotext(p, 0))

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 37dad18d..b2d48a51 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -12,13 +12,13 @@ import ldm.modules.midas as midas

 from ldm.util import instantiate_from_config

-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer

 model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(models_path, model_dir))
+model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

 checkpoints_list = {}
 checkpoint_alisases = {}
@@ -307,7 +307,7 @@ def enable_midas_autodownload():
     location automatically.
     """

-    midas_path = os.path.join(models_path, 'midas')
+    midas_path = os.path.join(paths.models_path, 'midas')

     # stable-diffusion-stability-ai hard-codes the midas model path to
     # a location that differs from where other scripts using this model look.

diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 4ce238b8..9b00f76e 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -3,13 +3,12 @@ import safetensors.torch
 import os
 import collections
 from collections import namedtuple
-from modules import shared, devices, script_callbacks, sd_models
-from modules.paths import models_path
+from modules import paths, shared, devices, script_callbacks, sd_models
 import glob
 from copy import deepcopy

-vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
+vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
 vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
 vae_dict = {}

diff --git a/modules/shared.py b/modules/shared.py
index 14be993d..474fcc42 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.memmon
 import modules.styles
 import modules.devices as devices
 from modules import localization, extensions, script_loading, errors, ui_components, shared_items
-from modules.paths import models_path, script_path
+from modules.paths import models_path, script_path, data_path

 demo = None
@@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file

 parser = argparse.ArgumentParser()
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
 parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
 parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
@@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo
 parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
-parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
 parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
 parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
 parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
@@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
-parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
+parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
 parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
 parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
-parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
+parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
 parser.add_argument("--gradio-debug",  action='store_true', help="launch gradio with --debug option")
 parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
 parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
 parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
 parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
-parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
+parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
 parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
 parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
 parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)

diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index c0ac11d3..2239cb84 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -6,8 +6,7 @@ import sys
 import tqdm
 import time

-from modules import shared, images, deepbooru
-from modules.paths import models_path
+from modules import paths, shared, images, deepbooru
 from modules.shared import opts, cmd_opts
 from modules.textual_inversion import autocrop
@@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre

     dnn_model_path = None
     try:
-        dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+        dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
     except Exception as e:
         print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
Falling back to lower quality haar method.", e) diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..0117df3e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path +from modules.paths import script_path, data_path from modules.shared import opts, cmd_opts, restricted_opts @@ -1497,8 +1497,8 @@ def create_ui(): with open(cssfile, "r", encoding="utf8") as file: css += file.read() + "\n" - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + if os.path.exists(os.path.join(data_path, "user.css")): + with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file: css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 742e745e..66a41865 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url): normalized_url = normalize_git_url(url) assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' - tmpdir = os.path.join(paths.script_path, "tmp", dirname) + tmpdir = os.path.join(paths.data_path, "tmp", dirname) try: shutil.rmtree(tmpdir, True) diff --git a/modules/upscaler.py b/modules/upscaler.py index a5bf5acb..e2eaa730 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -11,7 +11,6 @@ from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) -from modules.paths import models_path class Upscaler: @@ -39,7 +38,7 @@ class Upscaler: self.mod_scale = None if self.model_path is None and self.name: - self.model_path = os.path.join(models_path, self.name) + self.model_path = os.path.join(shared.models_path, self.name) if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) @@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler): def __init__(self, dirname=None): super().__init__(False) self.name = "Nearest" - self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file + self.scalers = [UpscalerData("Nearest", None, self)] From 14c0884fd0948c478db165989cca7aaffc9a0504 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:55:59 +0100 Subject: [PATCH 083/151] use python importlib to load and execute extension modules Previously, module attributes like __file__ were not set correctly, leading to scripts getting the directory of the stable-diffusion repo location instead of their own script's location. This causes problems when loading user data from an external location using the --data-dir flag, as extensions would look for their own code in the stable-diffusion repo location instead of the data dir location. Using Python's importlib functions sets the module spec correctly and executes the module. But this will break extensions if they build paths based on the previously incorrect __file__ attribute. 
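The behavioral difference is easy to demonstrate: a module created with ModuleType and exec never receives a __file__ attribute unless one is assigned by hand, while importlib derives the module spec (and therefore __file__) from the path it loads. A minimal sketch of both approaches, assuming a hypothetical throwaway probe.py in the current directory:

import importlib.util
from types import ModuleType

path = "probe.py"  # hypothetical throwaway script, used only for this demo
with open(path, "w", encoding="utf8") as file:
    file.write("x = 1\n")

# old approach: exec into a bare module object; no spec is created, so no __file__
old_module = ModuleType("probe")
with open(path, "r", encoding="utf8") as file:
    exec(compile(file.read(), path, 'exec'), old_module.__dict__)
print(getattr(old_module, "__file__", None))  # None

# new approach: importlib builds the module from a spec tied to the file location
spec = importlib.util.spec_from_file_location("probe", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
print(module.__file__)  # probe.py

Extensions that call os.path.dirname(__file__) get a usable answer only under the second scheme, which is what the patch below switches to.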
--- modules/script_loading.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modules/script_loading.py b/modules/script_loading.py index f93f0951..a7d2203f 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,16 +1,14 @@ import os import sys import traceback +import importlib.util from types import ModuleType def load_module(path): - with open(path, "r", encoding="utf8") as file: - text = file.read() - - compiled = compile(text, path, 'exec') - module = ModuleType(os.path.basename(path)) - exec(compiled, module.__dict__) + module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) return module From 6b3981c0685cd1df750df4eb51823f1cfd70c6d5 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:00:09 +0100 Subject: [PATCH 084/151] clean up unused script_path imports --- modules/codeformer_model.py | 2 +- modules/generation_parameters_copypaste.py | 2 +- webui.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ab40d842..01fb7bd8 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,7 +8,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader -from modules.paths import script_path, models_path +from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 35f72808..773c5c0e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.paths import data_path, script_path +from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image diff --git a/webui.py b/webui.py index e1565a8d..41f32f5c 100644 --- a/webui.py +++ b/webui.py @@ -15,7 +15,6 @@ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call -from modules.paths import script_path import torch From 23a9d5e27390846dea0895a02c04aec9583a4d38 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:18:55 +0100 Subject: [PATCH 085/151] create user extensions directory if not exists --- modules/extensions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/extensions.py b/modules/extensions.py index 92ee8144..5e12b1aa 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -10,6 +10,8 @@ extensions = [] extensions_dir = os.path.join(paths.data_path, "extensions") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") +if not os.path.exists(extensions_dir): + os.makedirs(extensions_dir) def active(): return [x for x in extensions if x.enabled] From eafaf14167cf574ad0f918c10f60ef86aea9cd20 Mon Sep 17 00:00:00 2001 From: Gazzoo-byte <73721238+Gazzoo-byte@users.noreply.github.com> Date: Fri, 27 
Jan 2023 18:34:41 +0000 Subject: [PATCH 086/151] Add button to switch width and height Adds a button to switch width and height, allowing quick and easy switching between landscape and portrait. --- modules/ui.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..fb0e4d5c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -91,6 +91,13 @@ save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 +switch_values_symbol = '\U000021C5' # ⇅ + +def switch_width_and_height(width, height): + width_temp = width + width = height + height = width_temp + return width, height def plaintext_to_html(text): @@ -466,6 +473,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") if opts.dimensions_and_batch_together: + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") with gr.Column(elem_id="txt2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") @@ -566,6 +574,7 @@ def create_ui(): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) + res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) txt_prompt_img.change( fn=modules.images.image_data, @@ -728,6 +737,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") if opts.dimensions_and_batch_together: + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") with gr.Column(elem_id="img2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") @@ -865,6 +875,7 @@ def create_ui(): img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) + res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), From a6a5bfb15531b19ce0319593d67d05a356f49a65 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Fri, 27 Jan 2023 13:48:39 -0500 Subject: [PATCH 087/151] deepcopy pc.styles, allows for multiple style axis --- scripts/xyz_grid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 828c2d12..f2fe506c 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -123,7 +123,7 @@ def apply_vae(p, x, xs): def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): - p.styles = x.split(',') + p.styles.extend(x.split(',')) def format_value_add_label(p, opt, x): @@ -533,6 +533,7 @@ class Script(scripts.Script): return Processed(p, [], p.seed, "") pc = copy(p) + pc.styles = pc.styles[:] x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) z_opt.apply(pc, z, zs) From 32d389ef0f7c75dd85fc7aebe7bca279f36fed86 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Fri, 27 Jan 2023 14:04:23 -0500 Subject: [PATCH 088/151] changes remaining text from X/Y -> X/Y/Z --- README.md | 2 +- javascript/hints.js | 2 +- scripts/xyz_grid.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 
c6bd6f27..2149dcc5 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ A browser interface based on Gradio library for Stable Diffusion. - a man in a (tuxedo:1.21) - alternative syntax - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user) - Loopback, run img2img processing multiple times -- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters +- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters - Textual Inversion - have as many embeddings as you want and use any names you like for them - use multiple embeddings with different numbers of vectors per token diff --git a/javascript/hints.js b/javascript/hints.js index 3cf10e20..7b60b25e 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -50,7 +50,7 @@ titles = { "None": "Do not do anything special", "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)", - "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows", + "X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows", "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index f2fe506c..f0116055 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -499,7 +499,7 @@ class Script(scripts.Script): image_cell_count = p.n_iter * p.batch_size cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" plural_s = 's' if len(zs) > 1 else '' - print(f"X/Y plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})") + print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. 
(Total steps to process: {total_steps})") shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] From cc8c9b7474d917888a0bd069fcd59a458c67ae4b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 22:43:08 +0300 Subject: [PATCH 089/151] fix broken calls to find_checkpoint_config --- modules/extras.py | 4 ++-- modules/sd_hijack_ip2p.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 36123aa5..4f842be9 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -6,7 +6,7 @@ import shutil import torch import tqdm -from modules import shared, images, sd_models, sd_vae +from modules import shared, images, sd_models, sd_vae, sd_models_config from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -37,7 +37,7 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c): def config(x): - res = sd_models.find_checkpoint_config(x) if x else None + res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None return res if res != shared.sd_default_config else None if config_source == 0: diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py index 635f015f..3c727d3b 100644 --- a/modules/sd_hijack_ip2p.py +++ b/modules/sd_hijack_ip2p.py @@ -5,9 +5,9 @@ import gc import time def should_hijack_ip2p(checkpoint_info): - from modules import sd_models + from modules import sd_models_config ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename From 6b82efd737827bbeef202f04ff5a8faec9b64ef8 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 27 Jan 2023 20:06:19 -0500 Subject: [PATCH 090/151] add v2-inpainting model detection, and broaden v-model detection to include anything with 768 in the name --- modules/sd_models_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 4d1e92e1..73854a45 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") @@ -28,7 +29,9 @@ def guess_model_config_from_state_dict(sd, filename): return config_depth_model if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: - if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: + if diffusion_model_input.shape[1] == 9: + return config_sd2_inpainting + elif re.search(re_parametrization_v, fn) or "768" in fn: return config_sd2v else: return config_sd2 From 2aac1d97782b486f3a4a5209cf399dcdcb7bbb4d Mon Sep 17 00:00:00 2001 From: Andrii Skaliuk Date: Fri, 27 Jan 2023 17:32:31 -0800 Subject: [PATCH 091/151] Basic inpainting 
batch support Modifies batch UI to add optional inpainting support --- modules/img2img.py | 20 +++++++++++++++++--- modules/ui.py | 9 ++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 2168c8e2..fe9447c7 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -16,11 +16,16 @@ import modules.images as images import modules.scripts -def process_batch(p, input_dir, output_dir, args): +def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processing.fix_seed(p) images = shared.listfiles(input_dir) + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0 + if is_inpaint_batch: + print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") + print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") save_normally = output_dir == '' @@ -43,6 +48,15 @@ def process_batch(p, input_dir, output_dir, args): img = ImageOps.exif_transpose(img) p.init_images = [img] * p.batch_size + if is_inpaint_batch: + # try to find corresponding mask for an image using simple filename matching + mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image)) + # if not found use first one ("same mask for all images" use-case) + if not mask_image_path in inpaint_masks: + mask_image_path = inpaint_masks[0] + mask_image = Image.open(mask_image_path) + p.image_mask = mask_image + proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: proc = process_images(p) @@ -59,7 +73,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img @@ -139,7 +153,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args) + 
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..fddb9177 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -691,9 +691,15 @@ def create_ui(): with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' -gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>") + gr.HTML( + f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." + + f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"<br>Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}</p>
" + ) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") def copy_image(img): if isinstance(img, dict) and 'image' in img: @@ -838,6 +844,7 @@ def create_ui(): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, + img2img_batch_inpaint_mask_dir ] + custom_inputs, outputs=[ img2img_gallery, From 4c52dfe4ac98c53431ecd267d59f27391d3a63e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 08:30:17 +0300 Subject: [PATCH 092/151] make the detection for -v models less broad --- modules/sd_models_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 73854a45..00217990 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -31,7 +31,7 @@ def guess_model_config_from_state_dict(sd, filename): if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: return config_sd2_inpainting - elif re.search(re_parametrization_v, fn) or "768" in fn: + elif re.search(re_parametrization_v, fn): return config_sd2v else: return config_sd2 From 0834d4ce374225131e025540220c727e352a3e43 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 08:41:15 +0300 Subject: [PATCH 093/151] simplify #7284 --- modules/ui.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 3c0a4050..ca2c1eb6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -93,12 +93,6 @@ clear_prompt_symbol = '\U0001F5D1' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ -def switch_width_and_height(width, height): - width_temp = width - width = height - height = width_temp - return width, height - def plaintext_to_html(text): return ui_common.plaintext_to_html(text) @@ -574,7 +568,8 @@ def create_ui(): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) - res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) + + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) txt_prompt_img.change( fn=modules.images.image_data, @@ -882,7 +877,7 @@ def create_ui(): img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) - res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), From 7d1f2a3a495327341ef1b3238347864845799bb6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 10:21:31 +0300 Subject: [PATCH 094/151] remove waiting for input on version mismatch warning, change supported versions --- .gitignore | 1 - launch.py | 35 ++++++++++++----------------------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index c8be9688..0b1d17ca 100644 --- a/.gitignore +++ b/.gitignore @@ -33,4 +33,3 @@ notification.mp3 /test/stdout.txt /test/stderr.txt /cache.json 
-no_py_ver_warning diff --git a/launch.py b/launch.py index 4f5a4bc4..7614f9c9 100644 --- a/launch.py +++ b/launch.py @@ -18,35 +18,24 @@ skip_install = False def check_python_version(): - if not os.path.isfile("no_py_ver_warning"): - version = sys.version_info - version_range = None - if platform.system() == "Linux": - version_range = range(7 + 1, 11 + 1) - else: - version_range = range(7 + 1, 10 + 1) + version = sys.version_info + if platform.system() == "Windows": + supported_minors = [10] + else: + supported_minors = [7, 8, 9, 10, 11] - try: - assert version.major == 3 and version.minor in version_range, f""" -=== Warning === -This program was tested only with 3.10 Python, but you have {version.major}.{version.minor} Python. + if not (version.major == 3 and version.minor in supported_minors): + import modules.errors + + modules.errors.print_error_explanation(f""" +This program is tested with 3.10.6 Python, but you have {version.major}.{version.minor}.{version.micro}. If you encounter an error with "RuntimeError: Couldn't install torch." message, or any other error regarding unsuccessful package (library) installation, please downgrade (or upgrade) to the latest version of 3.10 Python and delete current Python and "venv" folder in WebUI's directory. -You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/ - -You will see this warning only once, delete file "no_py_ver_warning" file to show this warning again. -=== Warning === - -Press ENTER to continue...\ -""" - except AssertionError as e: - print(e) - with open("no_py_ver_warning", "w"): - pass - input() +You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/\ +""") def commit_hash(): From 3752aad23d4be4522f9edf3fe79c1122fa5ad509 Mon Sep 17 00:00:00 2001 From: Mackerel Date: Sat, 28 Jan 2023 02:44:12 -0500 Subject: [PATCH 095/151] don't replace regular --help with new paths.py parser help --- modules/paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/paths.py b/modules/paths.py index 08e6f9b9..d991cc71 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -6,7 +6,7 @@ import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) # Parse the --data-dir flag first so we can use it as a base for our other argument default values -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) cmd_opts_pre = parser.parse_known_args()[0] data_path = cmd_opts_pre.data_dir From bd52a6d89970cca4f0f8b4275db895c99e173b3f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 10:48:08 +0300 Subject: [PATCH 096/151] some more changes for python version warning; add a commandline flag to disable --- launch.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/launch.py b/launch.py index 7614f9c9..370920de 100644 --- a/launch.py +++ b/launch.py @@ -18,23 +18,33 @@ skip_install = False def check_python_version(): - version = sys.version_info - if platform.system() == "Windows": + is_windows = platform.system() == "Windows" + major = sys.version_info.major + minor = sys.version_info.minor + micro = sys.version_info.micro + + if is_windows: supported_minors = [10] else: supported_minors = [7, 8, 9, 10, 11] - if not (version.major == 3 and version.minor in 
supported_minors): + if not (major == 3 and minor in supported_minors): import modules.errors modules.errors.print_error_explanation(f""" -This program is tested with 3.10.6 Python, but you have {version.major}.{version.minor}.{version.micro}. +INCOMPATIBLE PYTHON VERSION + +This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}. If you encounter an error with "RuntimeError: Couldn't install torch." message, or any other error regarding unsuccessful package (library) installation, please downgrade (or upgrade) to the latest version of 3.10 Python and delete current Python and "venv" folder in WebUI's directory. -You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/\ +You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/ + +{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""} + +Use --skip-python-version-check to suppress this warning. """) @@ -237,6 +247,7 @@ def prepare_environment(): sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') + sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch') sys.argv, update_check = extract_arg(sys.argv, '--update-check') @@ -245,6 +256,9 @@ def prepare_environment(): xformers = '--xformers' in sys.argv ngrok = '--ngrok' in sys.argv + if not skip_python_version_check: + check_python_version() + commit = commit_hash() print(f"Python {sys.version}") @@ -342,6 +356,5 @@ def start(): if __name__ == "__main__": - check_python_version() prepare_environment() start() From ada17dbd7c4c68a4e559848d2e6f2a7799722806 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 27 Jan 2023 10:19:43 -0500 Subject: [PATCH 097/151] Refactor conditional casting, fix upscalers --- modules/devices.py | 8 ++++++++ modules/processing.py | 15 ++++++++------- modules/realesrgan_model.py | 2 +- modules/sd_hijack.py | 2 +- modules/sd_hijack_unet.py | 8 +++++++- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..0100e4af 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -83,6 +83,14 @@ dtype_unet = torch.float16 unet_needs_upcast = False +def cond_cast_unet(input): + return input.to(dtype_unet) if unet_needs_upcast else input + + +def cond_cast_float(input): + return input.float() if unet_needs_upcast else input + + def randn(seed, shape): torch.manual_seed(seed) if device.type == 'mps': diff --git a/modules/processing.py b/modules/processing.py index 92894d67..a397702b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,8 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... 
-> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) - conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -217,7 +216,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -228,16 +227,18 @@ class StableDiffusionProcessing: return image_conditioning def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + source_image = devices.cond_cast_float(source_image) + # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) + return self.depth2img_image_conditioning(source_image) if self.sd_model.cond_stage_key == "edit": return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -417,7 +418,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) + x = model.decode_first_stage(x) return x @@ -1001,7 +1002,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. 
- image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) + image = image.to(shared.device) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 47f70251..aad4a629 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler): scale=info.scale, model_path=info.local_data_path, model=info.model(), - half=not cmd_opts.no_half, + half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, tile=opts.ESRGAN_tile, tile_pad=opts.ESRGAN_tile_overlap, ) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 531790f3..8fc91882 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec + emb = devices.cond_cast_unet(embedding.vec) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index a6ee577c..45cf2b18 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.1"): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) + +first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 +first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) From 02b8b957d763d0fc29551d13d8a2005615e8ce7a Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 28 Jan 2023 00:16:22 -0500 Subject: [PATCH 098/151] Add 
--no-half-vae to default macOS arguments Apparently the version of PyTorch macOS users are currently at doesn't always handle half precision VAEs correctly. We will probably want to update the default PyTorch version to 2.0 when it comes out which should fix that, and at this point nightly builds of PyTorch 2.0 are going to be recommended for most Mac users. Unfortunately someone has already reported that their M2 Mac doesn't work with the nightly PyTorch 2.0 build currently, so we can add --no-half-vae for now and give users that can install nightly PyTorch 2.0 builds a webui-user.sh configuration that overrides the default. --- webui-macos-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index fa187dd1..37cac4fb 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -10,7 +10,7 @@ then fi export install_dir="$HOME" -export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate" +export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" From f9edd578e9e29d160e6d56038bb368dc49895d64 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 28 Jan 2023 00:20:30 -0500 Subject: [PATCH 099/151] Remove MPS fix no longer needed for PyTorch The torch.narrow fix was required for nightly PyTorch builds for a while to prevent a hard crash, but newer nightly builds don't have this issue. --- modules/devices.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0100e4af..be542f8f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -201,6 +201,3 @@ if has_mps(): cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) - orig_narrow = torch.narrow - torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) - From 4aa7f5b5b996c1e3d97640e746f040a23a124860 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 11:11:47 +0300 Subject: [PATCH 100/151] update image parameters regex for #7231 --- modules/generation_parameters_copypaste.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 773c5c0e..1bf35bbb 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,7 +11,7 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") From d04e3e921e8ee71442a1f4a1d6e91c05b8238007 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 15:24:29 +0300 Subject: [PATCH 101/151] automatically detect v-parameterization for SD2 checkpoints --- 
modules/sd_hijack.py | 2 ++ modules/sd_models_config.py | 51 +++++++++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f9652d21..03897b2a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -131,6 +131,8 @@ class StableDiffusionModelHijack: m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped m.cond_stage_model = m.cond_stage_model.wrapped + undo_optimizations() + self.apply_circular(False) self.layers = None self.clip = None diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 00217990..91c21700 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -1,7 +1,9 @@ import re import os -from modules import shared, paths +import torch + +from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") @@ -16,12 +18,51 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml" config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") -re_parametrization_v = re.compile(r'-v\b') + +def is_using_v_parameterization_for_sd2(state_dict): + """ + Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. + """ + + import ldm.modules.diffusionmodules.openaimodel + from modules import devices + + device = devices.cpu + + with sd_disable_initialization.DisableInitialization(): + unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( + use_checkpoint=True, + use_fp16=False, + image_size=32, + in_channels=4, + out_channels=4, + model_channels=320, + attention_resolutions=[4, 2, 1], + num_res_blocks=2, + channel_mult=[1, 2, 4, 4], + num_head_channels=64, + use_spatial_transformer=True, + use_linear_in_transformer=True, + transformer_depth=1, + context_dim=1024, + legacy=False + ) + unet.eval() + + with torch.no_grad(): + unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} + unet.load_state_dict(unet_sd, strict=True) + unet.to(device=device, dtype=torch.float) + + test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 + x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 + + out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() + + return out < -1 def guess_model_config_from_state_dict(sd, filename): - fn = os.path.basename(filename) - sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) @@ -31,7 +72,7 @@ def guess_model_config_from_state_dict(sd, filename): if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: return config_sd2_inpainting - elif re.search(re_parametrization_v, fn): + elif is_using_v_parameterization_for_sd2(sd): return config_sd2v else: return config_sd2 From f8feeaaedb890de1e36eeb2ad387f0eb3abafd54 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 15:57:56 +0300 Subject: [PATCH 102/151] add progressbar to extension update check; do not check for updates for disabled extensions --- javascript/extensions.js | 20 +++++++++++++++++--- modules/ui_extensions.py | 28 ++++++++++++++++++---------- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/javascript/extensions.js b/javascript/extensions.js index ac6e35b9..c593cd2e 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -1,7 +1,8 @@ function extensions_apply(_, _){ - disable = [] - update = [] + var disable = [] + var update = [] + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ if(x.name.startsWith("enable_") && ! x.checked) disable.push(x.name.substr(7)) @@ -16,11 +17,24 @@ function extensions_apply(_, _){ } function extensions_check(){ + var disable = [] + + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ + if(x.name.startsWith("enable_") && ! x.checked) + disable.push(x.name.substr(7)) + }) + gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ x.innerHTML = "Loading..." 
}) - return [] + + var id = randomId() + requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){ + + }) + + return [id, JSON.stringify(disable)] } function install_extension_from_index(button, url){ diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 66a41865..37d30e1f 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -13,7 +13,7 @@ import shutil import errno from modules import extensions, shared, paths - +from modules.call_queue import wrap_gradio_gpu_call available_extensions = {"extensions": []} @@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list): shared.state.need_restart = True -def check_updates(): +def check_updates(id_task, disable_list): check_access() - for ext in extensions.extensions: - if ext.remote is None: - continue + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] + shared.state.job_count = len(exts) + + for ext in exts: + shared.state.textinfo = ext.name try: ext.check_updates() @@ -63,7 +68,9 @@ def check_updates(): print(f"Error checking updates for {ext.name}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - return extension_table() + shared.state.nextjob() + + return extension_table(), "" def extension_table(): @@ -273,12 +280,13 @@ def create_ui(): with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.TabItem("Installed"): - with gr.Row(): + with gr.Row(elem_id="extensions_installed_top"): apply = gr.Button(value="Apply and restart UI", variant="primary") check = gr.Button(value="Check for updates") extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) + info = gr.HTML() extensions_table = gr.HTML(lambda: extension_table()) apply.click( @@ -289,10 +297,10 @@ def create_ui(): ) check.click( - fn=check_updates, + fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]), _js="extensions_check", - inputs=[], - outputs=[extensions_table], + inputs=[info, extensions_disabled_list], + outputs=[extensions_table, info], ) with gr.TabItem("Available"): From 5d14f282c2812888275902be4b552681f942dbfd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 16:23:49 +0300 Subject: [PATCH 103/151] fixed a bug where after switching to a checkpoint with unknown hash, you'd get empty space instead of the checkpoint name in the UI; fixed a bug where if you update a selected checkpoint on disk and then restart the program, a different checkpoint loads, but the name shown is the old one. 
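The diff below removes a conditional sync: sd_model_checkpoint was previously written back to settings only when calculate_shorthash() changed the checkpoint title, so a checkpoint whose hash was already known could leave a stale name behind. A rough sketch of the before/after logic, with FakeCheckpointInfo as a hypothetical stand-in for the real sd_models.CheckpointInfo:

class FakeCheckpointInfo:
    # hypothetical stand-in; the real class builds its title in the same general way
    def __init__(self, name, shorthash=None):
        self.name = name
        self.shorthash = shorthash
        self.title = f"{name} [{shorthash}]" if shorthash else name

    def calculate_shorthash(self):
        if self.shorthash is None:
            self.shorthash = "deadbeef"  # pretend the file was hashed here
            self.title = f"{self.name} [{self.shorthash}]"
        return self.shorthash

opts_data = {"sd_model_checkpoint": "old-model [cafebabe]"}  # stale setting
info = FakeCheckpointInfo("new-model", shorthash="deadbeef")  # hash already cached

# old logic: only sync the setting when hashing changes the title
title = info.title
info.calculate_shorthash()
if info.title != title:
    opts_data["sd_model_checkpoint"] = info.title  # skipped: title did not change
print(opts_data["sd_model_checkpoint"])  # still "old-model [cafebabe]" - the bug

# fixed logic: always record the freshly resolved title
opts_data["sd_model_checkpoint"] = info.title
print(opts_data["sd_model_checkpoint"])  # "new-model [deadbeef]"

The fix in the patch simply assigns the resolved title unconditionally.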
--- modules/sd_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index b2d48a51..c45ddf83 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -231,12 +231,10 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): - title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if checkpoint_info.title != title: - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) From 1421e959600e0e9a2435e48373a551237bbab814 Mon Sep 17 00:00:00 2001 From: Thurion Date: Sat, 28 Jan 2023 14:42:24 +0100 Subject: [PATCH 104/151] allow empty mask dir --- modules/img2img.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index fe9447c7..3ecb6146 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -21,8 +21,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): images = shared.listfiles(input_dir) - inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0 + is_inpaint_batch = False + if inpaint_mask_dir: + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = len(inpaint_masks) > 0 if is_inpaint_batch: print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") From b7d2af8c7fa48d6eef7517a6fbc63a3507c638d4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 17:18:47 +0300 Subject: [PATCH 105/151] add dropdowns in settings for hypernets and loras --- extensions-builtin/Lora/extra_networks_lora.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 3 +++ modules/extra_networks_hypernet.py | 8 +++++++- modules/shared.py | 5 +++-- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 8f2e753e..6be6ef73 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -1,4 +1,4 @@ -from modules import extra_networks +from modules import extra_networks, shared import lora class ExtraNetworkLora(extra_networks.ExtraNetwork): @@ -6,6 +6,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): super().__init__('lora') def activate(self, p, params_list): + additional = shared.opts.sd_lora + + if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + names = [] multipliers = [] for params in params_list: diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 544b228d..2e860160 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,4 +1,5 @@ import torch +import gradio as gr import lora import extra_networks_lora @@ -31,5 +32,7 @@ script_callbacks.on_before_ui(before_ui) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "sd_lora": 
shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), + })) diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index ff279a1f..d3a4d7ad 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -1,4 +1,4 @@ -from modules import extra_networks +from modules import extra_networks, shared, extra_networks from modules.hypernetworks import hypernetwork @@ -7,6 +7,12 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): super().__init__('hypernet') def activate(self, p, params_list): + additional = shared.opts.sd_hypernetwork + + if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + names = [] multipliers = [] for params in params_list: diff --git a/modules/shared.py b/modules/shared.py index 474fcc42..eb04e811 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -405,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), - "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) @@ -431,7 +430,9 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), })) options_templates.update(options_section(('extra_networks', "Extra Networks"), { - "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), })) options_templates.update(options_section(('ui', "User interface"), { From 591b68e56c53eed391d08ce008423191c573784d Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 28 Jan 2023 10:04:09 -0500 Subject: [PATCH 106/151] uses auto's new regex, checks len of re_param --- modules/generation_parameters_copypaste.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 13d0874d..53f1a865 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,7 +11,7 @@ from modules import shared, ui_tempdir, script_callbacks import 
tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*(\"[^\"]*\"|[^,]+)' +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") @@ -242,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model done_with_prompt = False *lines, lastline = x.strip().split("\n") - if not re_param.match(lastline): + if len(re_param.findall(lastline)) < 3: lines.append(lastline) lastline = '' From f4eeff659e18fc7683f426371394f48b58a00bd3 Mon Sep 17 00:00:00 2001 From: ItsOlegDm Date: Sat, 28 Jan 2023 17:05:08 +0200 Subject: [PATCH 107/151] Removed buttons centering --- style.css | 1 - 1 file changed, 1 deletion(-) diff --git a/style.css b/style.css index 1e90b105..3cbabfd6 100644 --- a/style.css +++ b/style.css @@ -136,7 +136,6 @@ #txt2img_tools, #img2img_tools{ gap: 0.4em; - justify-content: center; } #interrogate_col{ From 1e22f48f4dbef15d8b2ba353b6c3cd68c4d0b42e Mon Sep 17 00:00:00 2001 From: ItsOlegDm Date: Sat, 28 Jan 2023 17:08:38 +0200 Subject: [PATCH 108/151] img2img styles padding fix --- style.css | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/style.css b/style.css index 3cbabfd6..a4bab9be 100644 --- a/style.css +++ b/style.css @@ -156,14 +156,13 @@ #txt2img_styles_row, #img2img_styles_row{ gap: 0.25em; + margin-top: 0.3em; } #txt2img_styles_row > button, #img2img_styles_row > button{ margin: 0; } -#txt2img_styles_row { - margin-top: 0.3em; -} + #txt2img_styles, #img2img_styles{ padding: 0; } From e2c71a4bd41470b9503021db36be2ae65f345d97 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 18:12:53 +0300 Subject: [PATCH 109/151] prevent the browser from using a cached version of scripts when they change --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 9f4cfda1..4e082408 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1692,14 +1692,14 @@ def create_ui(): def reload_javascript(): - head = f'<script type="text/javascript" src="file=script.js"></script>\n' + head = f'<script type="text/javascript" src="file=script.js?{os.path.getmtime("script.js")}"></script>\n' inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: inline += f"set_theme('{cmd_opts.theme}');" for script in modules.scripts.list_scripts("javascript", ".js"): - head += f'<script type="text/javascript" src="file={script.path}"></script>\n' + head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n' head += f'<script type="text/javascript">{inline}</script>\n' From 29d2d6a094a1b4028b8d281f069f28bd4cacc944 Mon Sep 17 00:00:00 2001 From: ItsOlegDm Date: Sat, 28 Jan 2023 17:21:59 +0200 Subject: [PATCH 110/151] Train tab fix --- style.css | 1 - 1 file changed, 1 deletion(-) diff --git a/style.css b/style.css index a4bab9be..39312c89 100644 --- a/style.css +++ b/style.css @@ -765,7 +765,6 @@ footer { .dark .gr-compact{ background-color: rgb(31 41 55 / var(--tw-bg-opacity)); - align-items: center; margin-left: 0; } From 1d8e06d542176beade37d2d36cb57460c3c6772f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 22:52:27 +0300 Subject: [PATCH 111/151] add checkpoints tab for extra networks UI --- .../Lora/ui_extra_networks_lora.py | 2 +- javascript/ui.js | 7 ++++ modules/ui.py | 8 ++++ modules/ui_extra_networks.py | 37 ++++++++++++++++-- modules/ui_extra_networks_checkpoints.py | 38 +++++++++++++++++++ modules/ui_extra_networks_hypernets.py | 2 +- .../ui_extra_networks_textual_inversion.py | 2 +- webui.py | 6 ++- 8 files changed, 94 insertions(+), 8 deletions(-) create mode 100644 modules/ui_extra_networks_checkpoints.py diff --git 
a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 54a80d36..c1244b10 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -20,7 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): preview = None for file in previews: if os.path.isfile(file): - preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + preview = self.link_preview(file) break yield { diff --git a/javascript/ui.js b/javascript/ui.js index ba72623c..dd40e62d 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -309,3 +309,10 @@ function updateInput(target){ Object.defineProperty(e, "target", {value: target}) target.dispatchEvent(e); } + + +var desiredCheckpointName = null; +function selectCheckpoint(name){ + desiredCheckpointName = name; + gradioApp().getElementById('change_checkpoint').click() +} diff --git a/modules/ui.py b/modules/ui.py index 4e082408..f1195692 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1560,6 +1560,14 @@ def create_ui(): outputs=[component, text_settings], ) + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[component_dict['sd_model_checkpoint'], dummy_component], + outputs=[component_dict['sd_model_checkpoint'], text_settings], + ) + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] def get_settings_values(): diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c6ff889a..5730c879 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,4 +1,6 @@ import os.path +import urllib.parse +from pathlib import Path from modules import shared import gradio as gr @@ -8,12 +10,31 @@ import html from modules.generation_parameters_copypaste import image_from_url_text extra_pages = [] +allowed_dirs = set() def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" extra_pages.append(page) + allowed_dirs.clear() + allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) + + +def add_pages_to_demo(app): + def fetch_file(filename: str = ""): + from starlette.responses import FileResponse + + if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]): + raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + + if os.path.splitext(filename)[1].lower() != ".png": + raise ValueError(f"File cannot be fetched: {filename}. 
Only png.") + + # would profit from returning 304 + return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) + + app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) class ExtraNetworksPage: @@ -26,6 +47,9 @@ class ExtraNetworksPage: def refresh(self): pass + def link_preview(self, filename): + return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + def create_html(self, tabname): view = shared.opts.extra_networks_default_view items_html = '' @@ -54,13 +78,17 @@ class ExtraNetworksPage: def create_html_for_item(self, item, tabname): preview = item.get("preview", None) + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + args = { "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', - "prompt": item["prompt"], + "prompt": item.get("prompt", None), "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], - "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', + "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', } @@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path): parent_path = os.path.abspath(parent_path) child_path = os.path.abspath(child_path) - return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) + return child_path.startswith(parent_path) def setup_ui(ui, gallery): @@ -173,7 +201,8 @@ def setup_ui(ui, gallery): ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], outputs=[*ui.pages] ) + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py new file mode 100644 index 00000000..c66cb830 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints.py @@ -0,0 +1,38 @@ +import html +import json +import os +import urllib.parse + +from modules import shared, ui_extra_networks, sd_models + + +class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Checkpoints') + + def refresh(self): + shared.refresh_checkpoints() + + def list_items(self): + for name, checkpoint1 in sd_models.checkpoints_list.items(): + checkpoint: sd_models.CheckpointInfo = checkpoint1 + path, ext = os.path.splitext(checkpoint.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = self.link_preview(file) + break + + yield { + "name": checkpoint.model_name, + "filename": path, + "preview": preview, + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.ckpt_dir, sd_models.model_path] + diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 65d000cf..8c15f8eb 100644 --- 
a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -19,7 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): preview = None for file in previews: if os.path.isfile(file): - preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + preview = self.link_preview(file) break yield { diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index dbd23d2d..a9d3064b 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -19,7 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): preview = None if os.path.isfile(preview_file): - preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file)) + preview = self.link_preview(preview_file) yield { "name": embedding.name, diff --git a/webui.py b/webui.py index 41f32f5c..0d0b8364 100644 --- a/webui.py +++ b/webui.py @@ -12,7 +12,7 @@ from packaging import version import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -from modules import import_hook, errors, extra_networks +from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call @@ -119,6 +119,7 @@ def initialize(): ui_extra_networks.intialize() ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) extra_networks.initialize() extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) @@ -227,6 +228,8 @@ def webui(): if launch_api: create_api(app) + ui_extra_networks.add_pages_to_demo(app) + modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) @@ -254,6 +257,7 @@ def webui(): ui_extra_networks.intialize() ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) extra_networks.initialize() extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) From 0a8515085ef258d4b76fdc000f7ed9d55751d6b8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 23:31:48 +0300 Subject: [PATCH 112/151] make it so that clicking on hypernet/lora card one more time removes the related from the prompt --- javascript/extraNetworks.js | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c5a9adb3..b5536a34 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -48,10 +48,39 @@ function setupExtraNetworks(){ onUiLoaded(setupExtraNetworks) +var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; +var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; + +function tryToRemoveExtraNetworkFromPrompt(textarea, text){ + var m = text.match(re_extranet) + 
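+    // re_extranet captures the "type:name" part of a token like <hypernet:name:1.0>, so m[1] below identifies which card's token to look for in the prompt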
if(! m) return false
+
+    var partToSearch = m[1]
+    var replaced = false
+    var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
+        m = found.match(re_extranet);
+        if(m[1] == partToSearch){
+            replaced = true;
+            return ""
+        }
+        return found;
+    })
+
+    if(replaced){
+        textarea.value = newTextareaText
+        return true;
+    }
+
+    return false
+}
+
 function cardClicked(tabname, textToAdd, allowNegativePrompt){
     var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
 
-    textarea.value = textarea.value + " " + textToAdd
+    if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
+        textarea.value = textarea.value + " " + textToAdd
+    }
+
     updateInput(textarea)
 }

From fb58fa62400ba0e2cbe8703796d5573f5f91c398 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Sat, 28 Jan 2023 15:37:01 -0500
Subject: [PATCH 113/151] xyz plot now saves sub grids if opts.grid_save; also
 fixed draw legend not being applied to the z grid

---
 scripts/xyz_grid.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 3df40483..3122f6f6 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -286,23 +286,24 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
         print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
         return Processed(p, [])
 
-    grids = [None] * len(zs)
+    sub_grids = [None] * len(zs)
     for i in range(len(zs)):
         start_index = i * len(xs) * len(ys)
         end_index = start_index + len(xs) * len(ys)
         grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
         if draw_legend:
             grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
-
-        grids[i] = grid
+        sub_grids[i] = grid
         if include_sub_grids and len(zs) > 1:
             processed_result.images.insert(i+1, grid)
 
-    original_grid_size = grids[0].size
-    grids = images.image_grid(grids, rows=1)
-    processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]])
+    sub_grid_size = sub_grids[0].size
+    z_grid = images.image_grid(sub_grids, rows=1)
+    if draw_legend:
+        z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
+    processed_result.images[0] = z_grid
 
-    return processed_result
+    return processed_result, sub_grids
 
 
 class SharedSettingsStackHelper(object):
@@ -576,7 +577,7 @@ class Script(scripts.Script):
                 return res
 
         with SharedSettingsStackHelper():
-            processed = draw_xyz_grid(
+            processed, sub_grids = draw_xyz_grid(
                 p,
                 xs=xs,
                 ys=ys,
@@ -592,6 +593,10 @@ class Script(scripts.Script):
             second_axes_processed=second_axes_processed
         )
 
+        if opts.grid_save and len(sub_grids) > 1:
+            for sub_grid in sub_grids:
+                images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+
         if opts.grid_save:
             images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)

From 09a142a05a6da8bdd4f36678a098c2a573db181a Mon Sep 17 00:00:00 2001
From: glop102
Date: Sat, 28 Jan 2023 19:25:52 -0500
Subject: [PATCH 114/151] Reduce grid rows if larger than number of images
 available

When a set number of grid rows is specified in settings, it can lead to
situations where an entire row 
in the grid is empty. The most noticeable example
is the processing preview when the row count is set to 2, where it shows the
preview just fine but with a black rectangle under it.
---
 modules/images.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/modules/images.py b/modules/images.py
index 0bc3d524..ae3cdaf4 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -36,6 +36,8 @@ def image_grid(imgs, batch_size=1, rows=None):
     else:
         rows = math.sqrt(len(imgs))
         rows = round(rows)
+    if rows > len(imgs):
+        rows = len(imgs)
 
     cols = math.ceil(len(imgs) / rows)
 

From f6b7768f84a335d351ba8c0d4c34d78e59272339 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 29 Jan 2023 10:20:19 +0300
Subject: [PATCH 115/151] support for searching subdirectory names for extra
 networks

---
 extensions-builtin/Lora/ui_extra_networks_lora.py |  1 +
 html/extra-networks-card.html                     |  1 +
 javascript/extraNetworks.js                       |  2 +-
 modules/sd_models.py                              |  1 +
 modules/ui_extra_networks.py                      | 11 +++++++++++
 modules/ui_extra_networks_checkpoints.py          |  6 +++---
 modules/ui_extra_networks_hypernets.py            |  1 +
 modules/ui_extra_networks_textual_inversion.py    |  1 +
 8 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index c1244b10..22cabcb0 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -27,6 +27,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
                 "name": name,
                 "filename": path,
                 "preview": preview,
+                "search_term": self.search_terms_from_path(lora_on_disk.filename),
                 "prompt": json.dumps(f""),
                 "local_preview": path + ".png",
             }
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index aa9fca87..8a5e2fbd 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -4,6 +4,7 @@
 
+
{name}
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index b5536a34..231fafe5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -16,7 +16,7 @@ function setupExtraNetworksForTab(tabname){ searchTerm = search.value.toLowerCase() gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ - text = elem.querySelector('.name').textContent.toLowerCase() + text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase() elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" }) }); diff --git a/modules/sd_models.py b/modules/sd_models.py index c45ddf83..300387a9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,6 +41,7 @@ class CheckpointInfo: name = name[1:] self.name = name + self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 5730c879..29c6e196 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -50,6 +50,16 @@ class ExtraNetworksPage: def link_preview(self, filename): return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + def search_terms_from_path(self, filename, possible_directories=None): + abspath = os.path.abspath(filename) + + for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): + parentdir = os.path.abspath(parentdir) + if abspath.startswith(parentdir): + return abspath[len(parentdir):].replace('\\','/') + + return "" + def create_html(self, tabname): view = shared.opts.extra_networks_default_view items_html = '' @@ -90,6 +100,7 @@ class ExtraNetworksPage: "name": item["name"], "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "search_term": item.get("search_term", ""), } return self.card_page.format(**args) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index c66cb830..360579b0 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -14,8 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): shared.refresh_checkpoints() def list_items(self): - for name, checkpoint1 in sd_models.checkpoints_list.items(): - checkpoint: sd_models.CheckpointInfo = checkpoint1 + for name, checkpoint in sd_models.checkpoints_list.items(): path, ext = os.path.splitext(checkpoint.filename) previews = [path + ".png", path + ".preview.png"] @@ -26,9 +25,10 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): break yield { - "name": checkpoint.model_name, + "name": checkpoint.name_for_extra, "filename": path, "preview": preview, + "search_term": self.search_terms_from_path(checkpoint.filename), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", } diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8c15f8eb..57851088 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -26,6 +26,7 @@ class 
ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, + "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", } diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index a9d3064b..bb64eb81 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -25,6 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": preview, + "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", } From 659d602dce42608a664642021ea2441da045d189 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 29 Jan 2023 02:32:53 -0500 Subject: [PATCH 116/151] only returns ckpt directories if they are not none --- modules/ui_extra_networks_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index c66cb830..5b471671 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -34,5 +34,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): } def allowed_directories_for_previews(self): - return [shared.cmd_opts.ckpt_dir, sd_models.model_path] + return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] From 8d7382ab24756cdcc37e71406832814f4713c55e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:34:58 +0300 Subject: [PATCH 117/151] add buttons for auto-search in subdirectories for extra tabs --- javascript/extraNetworks.js | 9 +++++++++ modules/ui_extra_networks.py | 27 ++++++++++++++++++++++++++- style.css | 6 ++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 231fafe5..17bf2000 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -96,3 +96,12 @@ function saveCardPreview(event, tabname, filename){ event.stopPropagation() event.preventDefault() } + +function extraNetworksSearchButton(tabs_id, event){ + searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea') + button = event.target + text = button.classList.contains("search-all") ? 
"" : button.textContent.trim() + + searchTextarea.value = text + updateInput(searchTextarea) +} \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 29c6e196..83367968 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,3 +1,4 @@ +import glob import os.path import urllib.parse from pathlib import Path @@ -56,7 +57,7 @@ class ExtraNetworksPage: for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): parentdir = os.path.abspath(parentdir) if abspath.startswith(parentdir): - return abspath[len(parentdir):].replace('\\','/') + return abspath[len(parentdir):].replace('\\', '/') return "" @@ -64,6 +65,27 @@ class ExtraNetworksPage: view = shared.opts.extra_networks_default_view items_html = '' + subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: + for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): + if not os.path.isdir(x): + continue + + subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") + while subdir.startswith("/"): + subdir = subdir[1:] + + subdirs[subdir] = 1 + + if subdirs: + subdirs = {"": 1, **subdirs} + + subdirs_html = "".join([f""" + +""" for subdir in subdirs]) + for item in self.list_items(): items_html += self.create_html_for_item(item, tabname) @@ -72,6 +94,9 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" +
+{subdirs_html} +
{items_html}
diff --git a/style.css b/style.css index 39312c89..05572f66 100644 --- a/style.css +++ b/style.css @@ -807,7 +807,13 @@ footer { margin: 0.3em; } +.extra-network-subdirs{ + padding: 0.2em 0.35em; +} +.extra-network-subdirs button{ + margin: 0 0.15em; +} #txt2img_extra_networks .search, #img2img_extra_networks .search{ display: inline-block; From 920fe8057cb325e9835f70c0389499c51cbdd3b5 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 29 Jan 2023 03:36:16 -0500 Subject: [PATCH 118/151] fixes #7284 btn unbound error --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f1195692..7e193240 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -466,8 +466,8 @@ def create_ui(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") if opts.dimensions_and_batch_together: - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") with gr.Column(elem_id="txt2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") @@ -737,8 +737,8 @@ def create_ui(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") if opts.dimensions_and_batch_together: - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") with gr.Column(elem_id="img2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") From aa6e55e00140da6d73d3d360a5628c1b1316550d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:53:05 +0300 Subject: [PATCH 119/151] do not display the message for TI unless the list of loaded embeddings changed --- modules/textual_inversion/textual_inversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6cf00e65..a1a406c2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -112,6 +112,7 @@ class EmbeddingDatabase: self.skipped_embeddings = {} self.expected_shape = -1 self.embedding_dirs = {} + self.previously_displayed_embeddings = () def add_embedding_dir(self, path): self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) @@ -228,9 +229,12 @@ class EmbeddingDatabase: self.load_from_dir(embdir) embdir.update() - print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: - print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") + displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) + if 
self.previously_displayed_embeddings != displayed_embeddings:
+            self.previously_displayed_embeddings = displayed_embeddings
+            print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+            if len(self.skipped_embeddings) > 0:
+                print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
 
     def find_embedding_at_position(self, tokens, offset):
         token = tokens[offset]

From 00dab8f10defbbda579a1bc89c8d4e972c58a20d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 29 Jan 2023 11:53:24 +0300
Subject: [PATCH 120/151] remove Batch size and Batch pos from textinfo
 (goodbye)

---
 modules/processing.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index afab6790..2d295932 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -450,8 +450,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Batch size": (None if p.batch_size < 2 else p.batch_size),
-        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),

From 7c53f81caf817a7e7dc9c2fafebfcce269ecb1d7 Mon Sep 17 00:00:00 2001
From: Yevhenii Hurin
Date: Sun, 29 Jan 2023 15:29:03 +0200
Subject: [PATCH 121/151] Prompt selector for Prompt Matrix script

---
 scripts/prompt_matrix.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index dd95e588..702870ce 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -10,7 +10,7 @@ from modules import images
 from modules.processing import process_images, Processed
 from modules.shared import opts, cmd_opts, state
 import modules.sd_samplers
-
+from pprint import pprint
 
 def draw_xy_grid(xs, ys, x_label, y_label, cell):
     res = []
@@ -44,16 +44,23 @@ class Script(scripts.Script):
     def title(self):
         return "Prompt matrix"
 
-    def ui(self, is_img2img):
+    def ui(self, is_img2img): 
        put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
         different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
+        # Radio buttons for selecting the prompt between positive and negative
+        prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
 
-        return [put_at_start, different_seeds]
+        return [put_at_start, different_seeds, prompt_type]
 
-    def run(self, p, put_at_start, different_seeds):
+    def run(self, p, put_at_start, different_seeds, prompt_type):
         modules.processing.fix_seed(p)
+        # Raise error if prompt type is not positive or negative
+        if prompt_type not in ["positive", "negative"]:
+            raise ValueError(f"Unknown prompt 
type {prompt_type}") - original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt + prompt = p.prompt if prompt_type == "positive" else p.negative_prompt + original_prompt = prompt[0] if type(prompt) == list else prompt + positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt all_prompts = [] prompt_matrix_parts = original_prompt.split("|") @@ -73,9 +80,12 @@ class Script(scripts.Script): print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") - p.prompt = all_prompts + if prompt_type == "positive": + p.prompt = all_prompts + else: + p.negative_prompt = all_prompts p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))] - p.prompt_for_display = original_prompt + p.prompt_for_display = positive_prompt processed = process_images(p) grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) From edabd927296d389856ffc87c7789111a21a051a2 Mon Sep 17 00:00:00 2001 From: Yevhenii Hurin Date: Sun, 29 Jan 2023 16:05:59 +0200 Subject: [PATCH 122/151] Add delimiter selector to the Prompt Matrix script --- scripts/prompt_matrix.py | 47 ++++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 702870ce..89db9e63 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -12,6 +12,7 @@ from modules.shared import opts, cmd_opts, state import modules.sd_samplers from pprint import pprint + def draw_xy_grid(xs, ys, x_label, y_label, cell): res = [] @@ -33,7 +34,8 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell): res.append(processed.images[0]) grid = images.image_grid(res, rows=len(ys)) - grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) + grid = images.draw_grid_annotations( + grid, res[0].width, res[0].height, hor_texts, ver_texts) first_processed.images = [grid] @@ -45,56 +47,73 @@ class Script(scripts.Script): return "Prompt matrix" def ui(self, is_img2img): - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) - different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', + value=False, elem_id=self.elem_id("put_at_start")) + different_seeds = gr.Checkbox( + label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) # Radio buttons for selecting the prompt between positive and negative - prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive") + prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", + elem_id=self.elem_id("prompt_type"), value="positive") + # Radio buttons for selecting the delimiter to use in the resulting prompt + variations_delimiter = gr.Radio(["comma", "space"], label="Select delimiter", elem_id=self.elem_id( + "variations_delimiter"), value="comma") + return [put_at_start, different_seeds, prompt_type, variations_delimiter] - return [put_at_start, different_seeds, prompt_type] - - def run(self, p, put_at_start, different_seeds, prompt_type): + def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter): modules.processing.fix_seed(p) # Raise error if promp type is not positive or negative if prompt_type not in 
["positive", "negative"]: raise ValueError(f"Unknown prompt type {prompt_type}") + # Raise error if variations delimiter is not comma or space + if variations_delimiter not in ["comma", "space"]: + raise ValueError( + f"Unknown variations delimiter {variations_delimiter}") prompt = p.prompt if prompt_type == "positive" else p.negative_prompt original_prompt = prompt[0] if type(prompt) == list else prompt positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt + delimiter = ", " if variations_delimiter == "comma" else " " + all_prompts = [] prompt_matrix_parts = original_prompt.split("|") combination_count = 2 ** (len(prompt_matrix_parts) - 1) for combination_num in range(combination_count): - selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)] + selected_prompts = [text.strip().strip(',') for n, text in enumerate( + prompt_matrix_parts[1:]) if combination_num & (1 << n)] if put_at_start: selected_prompts = selected_prompts + [prompt_matrix_parts[0]] else: selected_prompts = [prompt_matrix_parts[0]] + selected_prompts - all_prompts.append(", ".join(selected_prompts)) + all_prompts.append(delimiter.join(selected_prompts)) p.n_iter = math.ceil(len(all_prompts) / p.batch_size) p.do_not_save_grid = True - print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") + print( + f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") if prompt_type == "positive": p.prompt = all_prompts else: p.negative_prompt = all_prompts - p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))] + p.seed = [p.seed + (i if different_seeds else 0) + for i in range(len(all_prompts))] p.prompt_for_display = positive_prompt processed = process_images(p) - grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) - grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts) + grid = images.image_grid(processed.images, p.batch_size, rows=1 << ( + (len(prompt_matrix_parts) - 1) // 2)) + grid = images.draw_prompt_matrix( + grid, p.width, p.height, prompt_matrix_parts) processed.images.insert(0, grid) processed.index_of_first_image = 1 processed.infotexts.insert(0, processed.infotexts[0]) if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p) + images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", + extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p) return processed From 5997457fd48c9c0fc31c9d96bf0b9c217585c526 Mon Sep 17 00:00:00 2001 From: Yevhenii Hurin Date: Sun, 29 Jan 2023 16:23:29 +0200 Subject: [PATCH 123/151] Compact options UI for Prompt Matrix --- scripts/prompt_matrix.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 89db9e63..03212b31 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -47,16 +47,23 @@ class Script(scripts.Script): return "Prompt matrix" def ui(self, is_img2img): - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', - value=False, elem_id=self.elem_id("put_at_start")) - different_seeds = gr.Checkbox( - label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) - # Radio 
buttons for selecting the prompt between positive and negative
-        prompt_type = gr.Radio(["positive", "negative"], label="Select prompt",
-                               elem_id=self.elem_id("prompt_type"), value="positive")
-        # Radio buttons for selecting the delimiter to use in the resulting prompt
-        variations_delimiter = gr.Radio(["comma", "space"], label="Select delimiter", elem_id=self.elem_id(
-            "variations_delimiter"), value="comma")
+        gr.HTML('<br />
') + with gr.Row(): + with gr.Column(): + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', + value=False, elem_id=self.elem_id("put_at_start")) + with gr.Column(): + # Radio buttons for selecting the prompt between positive and negative + prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", + elem_id=self.elem_id("prompt_type"), value="positive") + with gr.Row(): + with gr.Column(): + different_seeds = gr.Checkbox( + label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) + with gr.Column(): + # Radio buttons for selecting the delimiter to use in the resulting prompt + variations_delimiter = gr.Radio(["comma", "space"], label="Select delimiter", elem_id=self.elem_id( + "variations_delimiter"), value="comma") return [put_at_start, different_seeds, prompt_type, variations_delimiter] def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter): From 1e2b10d2dcdf41a6cce0c525c85ebd42a521e0f1 Mon Sep 17 00:00:00 2001 From: Yevhenii Hurin Date: Sun, 29 Jan 2023 17:14:46 +0200 Subject: [PATCH 124/151] Cleanup changes made by formatter --- scripts/prompt_matrix.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 03212b31..de921ea8 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -10,7 +10,6 @@ from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state import modules.sd_samplers -from pprint import pprint def draw_xy_grid(xs, ys, x_label, y_label, cell): @@ -34,8 +33,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell): res.append(processed.images[0]) grid = images.image_grid(res, rows=len(ys)) - grid = images.draw_grid_annotations( - grid, res[0].width, res[0].height, hor_texts, ver_texts) + grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) first_processed.images = [grid] @@ -73,8 +71,7 @@ class Script(scripts.Script): raise ValueError(f"Unknown prompt type {prompt_type}") # Raise error if variations delimiter is not comma or space if variations_delimiter not in ["comma", "space"]: - raise ValueError( - f"Unknown variations delimiter {variations_delimiter}") + raise ValueError(f"Unknown variations delimiter {variations_delimiter}") prompt = p.prompt if prompt_type == "positive" else p.negative_prompt original_prompt = prompt[0] if type(prompt) == list else prompt @@ -86,8 +83,7 @@ class Script(scripts.Script): prompt_matrix_parts = original_prompt.split("|") combination_count = 2 ** (len(prompt_matrix_parts) - 1) for combination_num in range(combination_count): - selected_prompts = [text.strip().strip(',') for n, text in enumerate( - prompt_matrix_parts[1:]) if combination_num & (1 << n)] + selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)] if put_at_start: selected_prompts = selected_prompts + [prompt_matrix_parts[0]] @@ -99,28 +95,23 @@ class Script(scripts.Script): p.n_iter = math.ceil(len(all_prompts) / p.batch_size) p.do_not_save_grid = True - print( - f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") if prompt_type == "positive": p.prompt = all_prompts else: p.negative_prompt = all_prompts - p.seed = [p.seed + (i if different_seeds 
else 0) - for i in range(len(all_prompts))] + p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))] p.prompt_for_display = positive_prompt processed = process_images(p) - grid = images.image_grid(processed.images, p.batch_size, rows=1 << ( - (len(prompt_matrix_parts) - 1) // 2)) - grid = images.draw_prompt_matrix( - grid, p.width, p.height, prompt_matrix_parts) + grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) + grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts) processed.images.insert(0, grid) processed.index_of_first_image = 1 processed.infotexts.insert(0, processed.infotexts[0]) if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", - extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p) + images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p) return processed From 938578e8a94883aa3c0075cf47eea64f66119541 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 00:25:30 +0300 Subject: [PATCH 125/151] make it so that setting options in pasted infotext (like Clip Skip and ENSD) do not get applied directly and instead are added as temporary overrides --- modules/generation_parameters_copypaste.py | 201 ++++++++++++++------- modules/shared.py | 37 +++- modules/txt2img.py | 6 +- modules/ui.py | 40 +++- modules/ui_common.py | 6 +- 5 files changed, 210 insertions(+), 80 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 3c098e0d..1292fead 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,4 +1,5 @@ import base64 +import html import io import math import os @@ -16,13 +17,23 @@ re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) + paste_fields = {} -bind_list = [] +registered_param_bindings = [] + + +class ParamBinding: + def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None): + self.paste_button = paste_button + self.tabname = tabname + self.source_text_component = source_text_component + self.source_image_component = source_image_component + self.source_tabname = source_tabname + self.override_settings_component = override_settings_component def reset(): paste_fields.clear() - bind_list.clear() def quote(text): @@ -74,26 +85,6 @@ def add_paste_fields(tabname, init_img, fields): modules.ui.img2img_paste_fields = fields -def integrate_settings_paste_fields(component_dict): - from modules import ui - - settings_map = { - 'CLIP_stop_at_last_layers': 'Clip skip', - 'inpainting_mask_weight': 'Conditional mask weight', - 'sd_model_checkpoint': 'Model hash', - 'eta_noise_seed_delta': 'ENSD', - 'initial_noise_multiplier': 'Noise multiplier', - } - settings_paste_fields = [ - (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None))) - for k, v in settings_map.items() - ] - - for tabname, info in paste_fields.items(): - if info["fields"] is not None: - info["fields"] += settings_paste_fields - - def create_buttons(tabs_list): buttons = {} for tab in tabs_list: @@ -101,9 +92,60 @@ def create_buttons(tabs_list): return buttons 
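+# Illustrative migration (hypothetical component names): code that called
+# bind_buttons({"txt2img": button}, gallery, generation_info) can instead call
+# register_paste_params_button(ParamBinding(paste_button=button, tabname="txt2img",
+#     source_text_component=generation_info, source_image_component=gallery))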
-#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab def bind_buttons(buttons, send_image, send_generate_info): - bind_list.append([buttons, send_image, send_generate_info]) + """old function for backwards compatibility; do not use this, use register_paste_params_button""" + for tabname, button in buttons.items(): + source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None + source_tabname = send_generate_info if isinstance(send_generate_info, str) else None + + register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)) + + +def register_paste_params_button(binding: ParamBinding): + registered_param_bindings.append(binding) + + +def connect_paste_params_buttons(): + binding: ParamBinding + for binding in registered_param_bindings: + destination_image_component = paste_fields[binding.tabname]["init_img"] + fields = paste_fields[binding.tabname]["fields"] + + destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) + destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) + + if binding.source_image_component and destination_image_component: + if isinstance(binding.source_image_component, gr.Gallery): + func = send_image_and_dimensions if destination_width_component else image_from_url_text + jsfunc = "extract_image_from_gallery" + else: + func = send_image_and_dimensions if destination_width_component else lambda x: x + jsfunc = None + + binding.paste_button.click( + fn=func, + _js=jsfunc, + inputs=[binding.source_image_component], + outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], + ) + + if binding.source_text_component is not None and fields is not None: + connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component) + + if binding.source_tabname is not None and fields is not None: + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_button.click( + fn=lambda *x: x, + inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], + outputs=[field for field, name in fields if name in paste_field_names], + ) + + binding.paste_button.click( + fn=None, + _js=f"switch_to_{binding.tabname}", + inputs=None, + outputs=None, + ) def send_image_and_dimensions(x): @@ -122,49 +164,6 @@ def send_image_and_dimensions(x): return img, w, h -def run_bind(): - for buttons, source_image_component, send_generate_info in bind_list: - for tab in buttons: - button = buttons[tab] - destination_image_component = paste_fields[tab]["init_img"] - fields = paste_fields[tab]["fields"] - - destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) - destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) - - if source_image_component and destination_image_component: - if isinstance(source_image_component, gr.Gallery): - func = send_image_and_dimensions if destination_width_component else image_from_url_text - jsfunc = 
"extract_image_from_gallery" - else: - func = send_image_and_dimensions if destination_width_component else lambda x: x - jsfunc = None - - button.click( - fn=func, - _js=jsfunc, - inputs=[source_image_component], - outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], - ) - - if send_generate_info and fields is not None: - if send_generate_info in paste_fields: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) - button.click( - fn=lambda *x: x, - inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], - outputs=[field for field, name in fields if name in paste_field_names], - ) - else: - connect_paste(button, fields, send_generate_info) - - button.click( - fn=None, - _js=f"switch_to_{tab}", - inputs=None, - outputs=None, - ) - def find_hypernetwork_key(hypernet_name, hypernet_hash=None): """Determines the config parameter name to use for the hypernet based on the parameters in the infotext. @@ -286,7 +285,47 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res -def connect_paste(button, paste_fields, input_comp, jsfunc=None): +settings_map = {} + +infotext_to_setting_name_mapping = [ + ('Clip skip', 'CLIP_stop_at_last_layers', ), + ('Conditional mask weight', 'inpainting_mask_weight'), + ('Model hash', 'sd_model_checkpoint'), + ('ENSD', 'eta_noise_seed_delta'), + ('Noise multiplier', 'initial_noise_multiplier'), +] + + +def create_override_settings_dict(text_pairs): + """creates processing's override_settings parameters from gradio's multiselect + + Example input: + ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337'] + + Example output: + {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337} + """ + + res = {} + + params = {} + for pair in text_pairs: + k, v = pair.split(":", maxsplit=1) + + params[k] = v.strip() + + for param_name, setting_name in infotext_to_setting_name_mapping: + value = params.get(param_name, None) + + if value is None: + continue + + res[setting_name] = shared.opts.cast_value(setting_name, value) + + return res + + +def connect_paste(button, paste_fields, input_comp, override_settings_component, jsfunc=None): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(data_path, "params.txt") @@ -323,6 +362,32 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): return res + if override_settings_component is not None: + def paste_settings(params): + vals = {} + + for param_name, setting_name in infotext_to_setting_name_mapping: + v = params.get(param_name, None) + if v is None: + continue + + if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: + continue + + v = shared.opts.cast_value(setting_name, v) + current_value = getattr(shared.opts, setting_name, None) + + if v == current_value: + continue + + vals[param_name] = v + + vals_pairs = [f"{k}: {v}" for k, v in vals.items()] + + return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) + + paste_fields = paste_fields + [(override_settings_component, paste_settings)] + button.click( fn=paste_func, _js=jsfunc, diff --git a/modules/shared.py b/modules/shared.py index eb04e811..b5370265 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -127,12 +127,13 @@ 
restricted_opts = {
 ui_reorder_categories = [
     "inpaint",
     "sampler",
+    "checkboxes",
+    "hires_fix",
     "dimensions",
     "cfg",
     "seed",
-    "checkboxes",
-    "hires_fix",
     "batch",
+    "override_settings",
     "scripts",
 ]
 
@@ -346,10 +347,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
 }))
 
 options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
-    "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
-    "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
+    "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
+    "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
     "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
-    "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
+    "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
     "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
 }))
 
@@ -605,11 +606,37 @@ class Options:
 
         self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
 
+    def cast_value(self, key, value):
+        """casts an arbitrary value to the same type as this setting's value with key
+        Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
+        """
+
+        if value is None:
+            return None
+
+        default_value = self.data_labels[key].default
+        if default_value is None:
+            default_value = getattr(self, key, None)
+        if default_value is None:
+            return None
+
+        expected_type = type(default_value)
+        if expected_type == bool and value == "False":
+            value = False
+        else:
+            value = expected_type(value)
+
+        return value
+
+
 opts = Options()
 if os.path.exists(config_filename):
     opts.load(config_filename)
 
+settings_components = None
+"""assigned from ui.py, a mapping of setting names to gradio components responsible for those settings"""
+
 latent_upscale_default_mode = "Latent"
 latent_upscale_modes = {
     "Latent": {"mode": "bilinear", "antialias": False},
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e945fd69..16841d0f 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,5 +1,6 @@
 import modules.scripts
 from modules import sd_samplers
+from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
     StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts
@@ -8,7 +9,9 @@ import modules.processing as processing
 from modules.ui import plaintext_to_html
 
 
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, 
seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args): + override_settings = create_override_settings_dict(override_settings_texts) + p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -38,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_second_pass_steps=hr_second_pass_steps, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, + override_settings=override_settings, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index f1195692..a7fcdd83 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -380,6 +380,7 @@ def apply_setting(key, value): opts.save(shared.config_filename) return getattr(opts, key) + def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -433,6 +434,18 @@ def get_value_for_setting(key): return gr.update(value=value, **args) +def create_override_settings_dropdown(tabname, row): + dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) + + dropdown.change( + fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + inputs=[dropdown], + outputs=[dropdown], + ) + + return dropdown + + def create_ui(): import modules.img2img import modules.txt2img @@ -503,6 +516,10 @@ def create_ui(): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + elif category == "override_settings": + with FormRow(elem_id="txt2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('txt2img', row) + elif category == "scripts": with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() @@ -524,7 +541,6 @@ def create_ui(): ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -555,6 +571,7 @@ def create_ui(): hr_second_pass_steps, hr_resize_x, hr_resize_y, + override_settings, ] + custom_inputs, outputs=[ @@ -615,6 +632,9 @@ def create_ui(): *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings, + )) txt2img_preview_params = [ txt2img_prompt, @@ -762,6 +782,10 @@ def create_ui(): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + elif category == "override_settings": + with FormRow(elem_id="img2img_override_settings_row") as row: + override_settings = 
create_override_settings_dropdown('img2img', row) + elif category == "scripts": with FormGroup(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() @@ -796,7 +820,6 @@ def create_ui(): ) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -937,6 +960,9 @@ def create_ui(): ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + )) modules.scripts.scripts_current = None @@ -954,7 +980,11 @@ def create_ui(): html2 = gr.HTML() with gr.Row(): buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) + + for tabname, button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image, + )) image.change( fn=wrap_gradio_call(modules.extras.run_pnginfo), @@ -1363,6 +1393,7 @@ def create_ui(): components = [] component_dict = {} + shared.settings_components = component_dict script_callbacks.ui_settings_callback() opts.reorder() @@ -1529,8 +1560,7 @@ def create_ui(): component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() + parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: for interface, label, ifid in interfaces: diff --git a/modules/ui_common.py b/modules/ui_common.py index 9405ac1f..fd047f31 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -198,5 +198,9 @@ Requested path was: {f} html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}') - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + for paste_tabname, paste_button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery + )) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log From f91068f426a359942d13bf7ec15b56562141b8d7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 00:37:26 +0300 Subject: [PATCH 126/151] change disable_weights_auto_swap to true by default --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index b5370265..96a2572f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -441,7 +441,7 @@ options_templates.update(options_section(('ui', "User interface"), { "do_not_show_images": OptionInfo(False, "Do not show any images in results for 
web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), - "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), + "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), "font": OptionInfo("", "Font for image grids that have text"), From 399720dac2543fb4cdbe1022ec1a01f2411b81e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 01:03:31 +0300 Subject: [PATCH 127/151] update prompt token counts after using the paste params button --- javascript/ui.js | 36 +++++++++++++++++----- modules/generation_parameters_copypaste.py | 6 ++-- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index dd40e62d..b7a8268a 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -191,6 +191,28 @@ function confirm_clear_prompt(prompt, negative_prompt) { return [prompt, negative_prompt] } + +promptTokecountUpdateFuncs = {} + +function recalculatePromptTokens(name){ + if(promptTokecountUpdateFuncs[name]){ + promptTokecountUpdateFuncs[name]() + } +} + +function recalculate_prompts_txt2img(){ + recalculatePromptTokens('txt2img_prompt') + recalculatePromptTokens('txt2img_neg_prompt') + return args_to_array(arguments); +} + +function recalculate_prompts_img2img(){ + recalculatePromptTokens('img2img_prompt') + recalculatePromptTokens('img2img_neg_prompt') + return args_to_array(arguments); +} + + opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -232,14 +254,12 @@ onUiUpdate(function(){ return } - prompt.parentElement.insertBefore(counter, prompt) counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" - textarea.addEventListener("input", function(){ - update_token_counter(id_button); - }); + promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); } + textarea.addEventListener("input", promptTokecountUpdateFuncs[id]); } registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') @@ -273,7 +293,7 @@ onOptionsChanged(function(){ let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 -let token_timeout; +let token_timeouts = {}; function update_txt2img_tokens(...args) { update_token_counter("txt2img_token_button") @@ -290,9 +310,9 @@ function update_img2img_tokens(...args) { } function update_token_counter(button_id) { - if (token_timeout) - clearTimeout(token_timeout); - token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); + if (token_timeouts[button_id]) + clearTimeout(token_timeouts[button_id]); + token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); } function restart_reload(){ diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 1292fead..2a10524f 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -130,7 +130,7 @@ def connect_paste_params_buttons(): ) if 
binding.source_text_component is not None and fields is not None: - connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component) + connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname) if binding.source_tabname is not None and fields is not None: paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) @@ -325,7 +325,7 @@ def create_override_settings_dict(text_pairs): return res -def connect_paste(button, paste_fields, input_comp, override_settings_component, jsfunc=None): +def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(data_path, "params.txt") @@ -390,7 +390,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, button.click( fn=paste_func, - _js=jsfunc, + _js=f"recalculate_prompts_{tabname}", inputs=[input_comp], outputs=[x[0] for x in paste_fields], ) From 847ceae1f71ee13e0a397da048d1bb418e8f36c1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 01:41:23 +0300 Subject: [PATCH 128/151] make it possible to search checkpoint by its hash --- modules/ui_extra_networks_checkpoints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index a6799171..04097a79 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -14,6 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): shared.refresh_checkpoints() def list_items(self): + checkpoint: sd_models.CheckpointInfo for name, checkpoint in sd_models.checkpoints_list.items(): path, ext = os.path.splitext(checkpoint.filename) previews = [path + ".png", path + ".preview.png"] @@ -28,7 +29,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "name": checkpoint.name_for_extra, "filename": path, "preview": preview, - "search_term": self.search_terms_from_path(checkpoint.filename), + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", } From c81b52ffbd6252842b3473a7aa8eb7ffc88ee7d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 02:40:26 +0300 Subject: [PATCH 129/151] add override settings component to img2img --- modules/img2img.py | 6 +++++- modules/ui.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 3ecb6146..f813299c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,6 +7,7 @@ import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops from modules import devices, sd_samplers +from modules.generation_parameters_copypaste import create_override_settings_dict from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state import modules.shared as shared @@ -75,7 +76,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, 
prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): + override_settings = create_override_settings_dict(override_settings_texts) + is_batch = mode == 5 if mode == 0: # img2img @@ -142,6 +145,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s inpaint_full_res=inpaint_full_res, inpaint_full_res_padding=inpaint_full_res_padding, inpainting_mask_invert=inpainting_mask_invert, + override_settings=override_settings, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index a7fcdd83..f910c582 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -872,7 +872,8 @@ def create_ui(): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, - img2img_batch_inpaint_mask_dir + img2img_batch_inpaint_mask_dir, + override_settings, ] + custom_inputs, outputs=[ img2img_gallery, @@ -961,7 +962,7 @@ def create_ui(): parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings, )) modules.scripts.scripts_current = None From cbd6329488beafe036ea3a3d0cea1a6940105cf9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:12:43 +0300 Subject: [PATCH 130/151] add an environment variable for selecting xformers package --- launch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 370920de..25909469 100644 --- a/launch.py +++ b/launch.py @@ -223,6 +223,7 @@ def prepare_environment(): requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") + xformers_package = os.environ.get('XFORMERS_PACKAGE', 
'xformers==0.0.16rc425') gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") @@ -282,7 +283,7 @@ def prepare_environment(): if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): - run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers") + run_pip(f"install -U -I --no-deps {xformers_package}", "xformers") else: print("Installation of xformers is not supported in this version of Python.") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") From 0c7c36a6c6f12da55e04bd79ae068daac8b586a1 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:35:52 +0300 Subject: [PATCH 131/151] Split history sd_samplers.py to sd_samplers_compvis.py --- modules/{sd_samplers.py => sd_samplers_compvis.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => sd_samplers_compvis.py} (100%) diff --git a/modules/sd_samplers.py b/modules/sd_samplers_compvis.py similarity index 100% rename from modules/sd_samplers.py rename to modules/sd_samplers_compvis.py From 9118b086068253c8f25c6277c385606b79c5b036 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:35:52 +0300 Subject: [PATCH 132/151] Split history sd_samplers.py to sd_samplers_compvis.py --- modules/{sd_samplers.py => temp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => temp} (100%) diff --git a/modules/sd_samplers.py b/modules/temp similarity index 100% rename from modules/sd_samplers.py rename to modules/temp From 449531a6c59b030b1cd7c3cba1113c47e0fc1c7d Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:35:53 +0300 Subject: [PATCH 133/151] Split history sd_samplers.py to sd_samplers_compvis.py --- modules/{temp => sd_samplers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{temp => sd_samplers.py} (100%) diff --git a/modules/temp b/modules/sd_samplers.py similarity index 100% rename from modules/temp rename to modules/sd_samplers.py From 5feae71dd218a3505f14505d71c6b335f9c642ac Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:37:50 +0300 Subject: [PATCH 134/151] Split history sd_samplers.py to sd_samplers_common.py --- modules/{sd_samplers.py => sd_samplers_common.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => sd_samplers_common.py} (100%) diff --git a/modules/sd_samplers.py b/modules/sd_samplers_common.py similarity index 100% rename from modules/sd_samplers.py rename to modules/sd_samplers_common.py From 6e78f6a8961875df11551650b4c5c8bddb6ed9ce Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:37:50 +0300 Subject: [PATCH 135/151] Split history sd_samplers.py to sd_samplers_common.py --- modules/{sd_samplers.py => temp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => temp} (100%) diff --git a/modules/sd_samplers.py b/modules/temp 
similarity index 100% rename from modules/sd_samplers.py rename to modules/temp From f8fcad502ec97ceb7ca4bf52f0f2efc8b80c0b64 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:37:51 +0300 Subject: [PATCH 136/151] Split history sd_samplers.py to sd_samplers_common.py --- modules/{temp => sd_samplers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{temp => sd_samplers.py} (100%) diff --git a/modules/temp b/modules/sd_samplers.py similarity index 100% rename from modules/temp rename to modules/sd_samplers.py From aa54a9d41680051b4b28b0655f8d61a2f23600b1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:51:06 +0300 Subject: [PATCH 137/151] split compvis sampler and shared sampler stuff into their own files --- modules/sd_samplers.py | 243 ++--------------- modules/sd_samplers_common.py | 479 +-------------------------------- modules/sd_samplers_compvis.py | 423 +---------------------------- 3 files changed, 28 insertions(+), 1117 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index a7910b56..9a29f1ae 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,22 +1,18 @@ -from collections import namedtuple, deque -import numpy as np -from math import floor +from collections import deque import torch -import tqdm -from PIL import Image import inspect import k_diffusion.sampling -import torchsde._brownian.brownian_interval import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing, images, sd_vae_approx +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis -from modules.shared import opts, cmd_opts, state +from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +# imports for functions that previously were here and are used by other modules +from modules.sd_samplers_common import samples_to_image_grid, sample_to_image -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), @@ -39,15 +35,15 @@ samplers_k_diffusion = [ ] samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) for label, funcname, aliases, options in samplers_k_diffusion if hasattr(k_diffusion.sampling, funcname) ] all_samplers = [ *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), ] all_samplers_map = {x.name: x for x in all_samplers} @@ -95,202 +91,6 @@ sampler_extra_params = { } -def setup_img2img_steps(p, steps=None): - if opts.img2img_fix_steps or steps is not None: - requested_steps = (steps or p.steps) - steps = int(requested_steps / 
min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 - t_enc = requested_steps - 1 - else: - steps = p.steps - t_enc = int(min(p.denoising_strength, 0.999) * steps) - - return steps, t_enc - - -approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} - - -def single_sample_to_image(sample, approximation=None): - if approximation is None: - approximation = approximation_indexes.get(opts.show_progress_type, 0) - - if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) - elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] - - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - return Image.fromarray(x_sample) - - -def sample_to_image(samples, index=0, approximation=None): - return single_sample_to_image(samples[index], approximation) - - -def samples_to_image_grid(samples, approximation=None): - return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) - - -def store_latent(decoded): - state.current_latent = decoded - - if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: - if not shared.parallel_processing_allowed: - shared.state.assign_current_image(sample_to_image(decoded)) - - -class InterruptedException(BaseException): - pass - - -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.step = 0 - self.stop_at = None - self.eta = None - self.default_eta = 0.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - if state.interrupted or state.skipped: - raise InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - if isinstance(cond, dict): - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < 
cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x_dec = img_orig * self.mask + self.nmask * x_dec - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. - if image_conditioning is not None: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) - - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * res[1] - else: - self.last_latent = res[1] - - store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - return res - - def initialize(self, p): - self.eta = p.eta if p.eta is not None else opts.eta_ddim - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): - valid_step = 999 / (1000 // num_steps) - if valid_step == floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - conditioning = 
{"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim - - class CFGDenoiser(torch.nn.Module): def __init__(self, model): super().__init__() @@ -312,7 +112,7 @@ class CFGDenoiser(torch.nn.Module): def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: - raise InterruptedException + raise sd_samplers_common.InterruptedException conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) @@ -354,9 +154,9 @@ class CFGDenoiser(torch.nn.Module): devices.test_for_nans(x_out, "unet") if opts.live_preview_content == "Prompt": - store_latent(x_out[0:uncond.shape[0]]) + sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) elif opts.live_preview_content == "Negative prompt": - store_latent(x_out[-uncond.shape[0]:]) + sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) @@ -395,19 +195,6 @@ class TorchHijack: return torch.randn_like(x) -# MPS fix for randn in torchsde -def torchsde_randn(size, dtype, device, seed): - if device.type == 'mps': - generator = torch.Generator(devices.cpu).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) - else: - generator = torch.Generator(device).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=device, generator=generator) - - -torchsde._brownian.brownian_interval._randn = torchsde_randn - - class KDiffusionSampler: def __init__(self, funcname, sd_model): denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser @@ -430,11 +217,11 @@ class KDiffusionSampler: step = d['i'] latent = d["denoised"] if opts.live_preview_content == "Combined": - store_latent(latent) + sd_samplers_common.store_latent(latent) self.last_latent = latent if self.stop_at is not None and step > self.stop_at: - raise InterruptedException + raise sd_samplers_common.InterruptedException state.sampling_step = step shared.total_tqdm.update() @@ -445,7 +232,7 @@ class KDiffusionSampler: try: return func() - except InterruptedException: + except sd_samplers_common.InterruptedException: return self.last_latent def number_of_needed_noises(self, p): @@ -492,7 +279,7 @@ class KDiffusionSampler: return sigmas def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) sigmas = self.get_sigmas(p, steps) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index a7910b56..5b06e341 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,99 +1,15 @@ from collections import namedtuple, deque import numpy as np -from math import floor import torch -import tqdm from PIL import Image -import inspect -import k_diffusion.sampling import 
torchsde._brownian.brownian_interval -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing, images, sd_vae_approx +from modules import devices, processing, images, sd_vae_approx -from modules.shared import opts, cmd_opts, state +from modules.shared import opts, state import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback - SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) -samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), - ('Euler', 'sample_euler', ['k_euler'], {}), - ('LMS', 'sample_lms', ['k_lms'], {}), - ('Heun', 'sample_heun', ['k_heun'], {}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), -] - -samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) - for label, funcname, aliases, options in samplers_k_diffusion - if hasattr(k_diffusion.sampling, funcname) -] - -all_samplers = [ - *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), -] -all_samplers_map = {x.name: x for x in all_samplers} - -samplers = [] -samplers_for_img2img = [] -samplers_map = {} - - -def create_sampler(name, model): - if name is not None: - config = all_samplers_map.get(name, None) - else: - config = all_samplers[0] - - assert config is not None, f'bad sampler name: {name}' - - sampler = config.constructor(model) - sampler.config = config - - return sampler - - -def set_samplers(): - global samplers, samplers_for_img2img - - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] - - samplers_map.clear() - for sampler in all_samplers: - samplers_map[sampler.name.lower()] = sampler.name - for alias in sampler.aliases: - samplers_map[alias.lower()] = sampler.name - - -set_samplers() - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - def setup_img2img_steps(p, steps=None): if 
opts.img2img_fix_steps or steps is not None: @@ -147,254 +63,6 @@ class InterruptedException(BaseException): pass -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.step = 0 - self.stop_at = None - self.eta = None - self.default_eta = 0.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - if state.interrupted or state.skipped: - raise InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - if isinstance(cond, dict): - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x_dec = img_orig * self.mask + self.nmask * x_dec - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. 
- if image_conditioning is not None: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) - - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * res[1] - else: - self.last_latent = res[1] - - store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - return res - - def initialize(self, p): - self.eta = p.eta if p.eta is not None else opts.eta_ddim - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): - valid_step = 999 / (1000 // num_steps) - if valid_step == floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim - - -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent 
= None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): - if state.interrupted or state.skipped: - raise InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. 
- self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - if x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - # MPS fix for randn in torchsde def torchsde_randn(size, dtype, device, seed): if device.type == 'mps': @@ -407,146 +75,3 @@ def torchsde_randn(size, dtype, device, seed): torchsde._brownian.brownian_interval._randn = torchsde_randn - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.default_eta = 1.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, 
device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index a7910b56..3d35ff72 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -1,150 +1,10 @@ -from collections import namedtuple, deque +import math + import numpy as np -from math import floor import torch -import tqdm -from PIL import Image -import inspect -import k_diffusion.sampling -import torchsde._brownian.brownian_interval -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing, images, sd_vae_approx -from modules.shared import opts, cmd_opts, state -import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback - - -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) - -samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), - ('Euler', 'sample_euler', ['k_euler'], {}), - ('LMS', 'sample_lms', ['k_lms'], {}), - ('Heun', 'sample_heun', ['k_heun'], {}), - ('DPM2', 
'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), -] - -samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) - for label, funcname, aliases, options in samplers_k_diffusion - if hasattr(k_diffusion.sampling, funcname) -] - -all_samplers = [ - *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), -] -all_samplers_map = {x.name: x for x in all_samplers} - -samplers = [] -samplers_for_img2img = [] -samplers_map = {} - - -def create_sampler(name, model): - if name is not None: - config = all_samplers_map.get(name, None) - else: - config = all_samplers[0] - - assert config is not None, f'bad sampler name: {name}' - - sampler = config.constructor(model) - sampler.config = config - - return sampler - - -def set_samplers(): - global samplers, samplers_for_img2img - - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] - - samplers_map.clear() - for sampler in all_samplers: - samplers_map[sampler.name.lower()] = sampler.name - for alias in sampler.aliases: - samplers_map[alias.lower()] = sampler.name - - -set_samplers() - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - - -def setup_img2img_steps(p, steps=None): - if opts.img2img_fix_steps or steps is not None: - requested_steps = (steps or p.steps) - steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 - t_enc = requested_steps - 1 - else: - steps = p.steps - t_enc = int(min(p.denoising_strength, 0.999) * steps) - - return steps, t_enc - - -approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} - - -def single_sample_to_image(sample, approximation=None): - if approximation is None: - approximation = approximation_indexes.get(opts.show_progress_type, 0) - - if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) - elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - else: - 
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] - - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - return Image.fromarray(x_sample) - - -def sample_to_image(samples, index=0, approximation=None): - return single_sample_to_image(samples[index], approximation) - - -def samples_to_image_grid(samples, approximation=None): - return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) - - -def store_latent(decoded): - state.current_latent = decoded - - if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: - if not shared.parallel_processing_allowed: - shared.state.assign_current_image(sample_to_image(decoded)) - - -class InterruptedException(BaseException): - pass +from modules.shared import state +from modules import sd_samplers_common, prompt_parser, shared class VanillaStableDiffusionSampler: @@ -174,15 +34,15 @@ class VanillaStableDiffusionSampler: try: return func() - except InterruptedException: + except sd_samplers_common.InterruptedException: return self.last_latent def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): if state.interrupted or state.skipped: - raise InterruptedException + raise sd_samplers_common.InterruptedException if self.stop_at is not None and self.step > self.stop_at: - raise InterruptedException + raise sd_samplers_common.InterruptedException # Have to unwrap the inpainting conditioning here to perform pre-processing image_conditioning = None @@ -224,7 +84,7 @@ class VanillaStableDiffusionSampler: else: self.last_latent = res[1] - store_latent(self.last_latent) + sd_samplers_common.store_latent(self.last_latent) self.step += 1 state.sampling_step = self.step @@ -233,7 +93,7 @@ class VanillaStableDiffusionSampler: return res def initialize(self, p): - self.eta = p.eta if p.eta is not None else opts.eta_ddim + self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): @@ -245,13 +105,13 @@ class VanillaStableDiffusionSampler: def adjust_steps_if_invalid(self, p, num_steps): if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) - if valid_step == floor(valid_step): + if valid_step == math.floor(valid_step): return int(valid_step) + 1 return num_steps def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) steps = self.adjust_steps_if_invalid(p, steps) self.initialize(p) @@ -289,264 +149,3 @@ class VanillaStableDiffusionSampler: samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) return samples_ddim - - -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - 
denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): - if state.interrupted or state.skipped: - raise InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. 
- self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - if x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - -# MPS fix for randn in torchsde -def torchsde_randn(size, dtype, device, seed): - if device.type == 'mps': - generator = torch.Generator(devices.cpu).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) - else: - generator = torch.Generator(device).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=device, generator=generator) - - -torchsde._brownian.brownian_interval._randn = torchsde_randn - - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.default_eta = 1.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - 
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - From f4d0538bf2f6430b145bb26a294b7f82b50f031a Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:51:23 +0300 Subject: [PATCH 138/151] Split history sd_samplers.py to sd_samplers_kdiffusion.py --- modules/{sd_samplers.py => sd_samplers_kdiffusion.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => sd_samplers_kdiffusion.py} (100%) diff --git a/modules/sd_samplers.py b/modules/sd_samplers_kdiffusion.py similarity index 100% rename from modules/sd_samplers.py rename to modules/sd_samplers_kdiffusion.py From 2db8ed32cd71fab68169dcb1b49998917190e3c7 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:51:23 +0300 Subject: [PATCH 139/151] Split history sd_samplers.py to sd_samplers_kdiffusion.py --- modules/{sd_samplers.py => temp} 
| 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename modules/{sd_samplers.py => temp} (100%)

diff --git a/modules/sd_samplers.py b/modules/temp
similarity index 100%
rename from modules/sd_samplers.py
rename to modules/temp


From 274474105a5166a985a47508ffd0695db41623a5 Mon Sep 17 00:00:00 2001
From: Andrey <16777216c@gmail.com>
Date: Mon, 30 Jan 2023 09:51:23 +0300
Subject: [PATCH 140/151] Split history sd_samplers.py to
 sd_samplers_kdiffusion.py

---
 modules/{temp => sd_samplers.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename modules/{temp => sd_samplers.py} (100%)

diff --git a/modules/temp b/modules/sd_samplers.py
similarity index 100%
rename from modules/temp
rename to modules/sd_samplers.py


From 4df63d2d197f26181758b5108f003f225fe84874 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 30 Jan 2023 10:11:30 +0300
Subject: [PATCH 141/151] split samplers into one more file for k-diffusion

---
 modules/sd_samplers.py            | 302 +-----------------------------
 modules/sd_samplers_common.py     |   3 +-
 modules/sd_samplers_compvis.py    |   8 +
 modules/sd_samplers_kdiffusion.py |  57 +-----
 4 files changed, 22 insertions(+), 348 deletions(-)

diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 9a29f1ae..28c2136f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,49 +1,11 @@
-from collections import deque
-import torch
-import inspect
-import k_diffusion.sampling
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis
-
-from modules.shared import opts, state
-import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared

 # imports for functions that previously were here and are used by other modules
 from modules.sd_samplers_common import samples_to_image_grid, sample_to_image

-
-samplers_k_diffusion = [
-    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
-    ('Euler', 'sample_euler', ['k_euler'], {}),
-    ('LMS', 'sample_lms', ['k_lms'], {}),
-    ('Heun', 'sample_heun', ['k_heun'], {}),
-    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
-    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
-    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
-    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
-    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
-    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
-    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
-    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
-    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
-    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
-    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
-    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
-    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
-]
-
-samplers_data_k_diffusion = [
-    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
-    for label, funcname, aliases, options in samplers_k_diffusion
-    if hasattr(k_diffusion.sampling, funcname)
-]
-
all_samplers = [ - *samplers_data_k_diffusion, - sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + *sd_samplers_kdiffusion.samplers_data_k_diffusion, + *sd_samplers_compvis.samplers_data_compvis, ] all_samplers_map = {x.name: x for x in all_samplers} @@ -69,8 +31,8 @@ def create_sampler(name, model): def set_samplers(): global samplers, samplers_for_img2img - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) + hidden = set(shared.opts.hide_samplers) + hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] @@ -83,257 +45,3 @@ def set_samplers(): set_samplers() - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - - -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], 
sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - if x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.default_eta = 1.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config 
is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 5b06e341..3c03d442 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,4 +1,4 @@ -from collections import namedtuple, deque +from collections import namedtuple import numpy as np import torch from PIL import 
Image @@ -64,6 +64,7 @@ class InterruptedException(BaseException): # MPS fix for randn in torchsde +# XXX move this to separate file for MPS def torchsde_randn(size, dtype, device, seed): if device.type == 'mps': generator = torch.Generator(devices.cpu).manual_seed(int(seed)) diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 3d35ff72..88541193 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -1,4 +1,6 @@ import math +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms import numpy as np import torch @@ -7,6 +9,12 @@ from modules.shared import state from modules import sd_samplers_common, prompt_parser, shared +samplers_data_compvis = [ + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), +] + + class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 9a29f1ae..adb6883e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,18 +2,12 @@ from collections import deque import torch import inspect import k_diffusion.sampling -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -# imports for functions that previously were here and are used by other modules -from modules.sd_samplers_common import samples_to_image_grid, sample_to_image - - samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), ('Euler', 'sample_euler', ['k_euler'], {}), @@ -40,50 +34,6 @@ samplers_data_k_diffusion = [ if hasattr(k_diffusion.sampling, funcname) ] -all_samplers = [ - *samplers_data_k_diffusion, - sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), -] -all_samplers_map = {x.name: x for x in all_samplers} - -samplers = [] -samplers_for_img2img = [] -samplers_map = {} - - -def create_sampler(name, model): - if name is not None: - config = all_samplers_map.get(name, None) - else: - config = all_samplers[0] - - assert config is not None, f'bad sampler name: {name}' - - sampler = config.constructor(model) - sampler.config = config - - return sampler - - -def set_samplers(): - global samplers, samplers_for_img2img - - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] - - samplers_map.clear() - for sampler in all_samplers: - samplers_map[sampler.name.lower()] = sampler.name - for alias in sampler.aliases: - samplers_map[alias.lower()] = sampler.name - - -set_samplers() - sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 
's_tmax', 's_noise'],
@@ -92,6 +42,13 @@ sampler_extra_params = {


 class CFGDenoiser(torch.nn.Module):
+    """
+    Classifier-free guidance denoiser. A wrapper for the Stable Diffusion model (specifically for the unet)
+    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+    instead of one. Originally, the second prompt is just an empty string, but we use a non-empty
+    negative prompt.
+    """
+
     def __init__(self, model):
         super().__init__()
         self.inner_model = model


From 040ec7a80e23d340efe1108b9de5ead62d9011a9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 30 Jan 2023 10:47:09 +0300
Subject: [PATCH 142/151] make the program read Eta and Eta DDIM from
 generation parameters

---
 modules/generation_parameters_copypaste.py | 2 ++
 modules/processing.py                      | 1 -
 modules/sd_samplers_compvis.py             | 3 ++-
 modules/sd_samplers_kdiffusion.py          | 8 +++++---
 4 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 2a10524f..7ee8ee10 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -293,6 +293,8 @@ infotext_to_setting_name_mapping = [
     ('Model hash', 'sd_model_checkpoint'),
     ('ENSD', 'eta_noise_seed_delta'),
     ('Noise multiplier', 'initial_noise_multiplier'),
+    ('Eta', 'eta_ancestral'),
+    ('Eta DDIM', 'eta_ddim'),
 ]


diff --git a/modules/processing.py b/modules/processing.py
index 2d295932..e544c2e1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -455,7 +455,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
-        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
     }

diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index 88541193..d03131cd 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -27,7 +27,6 @@ class VanillaStableDiffusionSampler:
         self.step = 0
         self.stop_at = None
         self.eta = None
-        self.default_eta = 0.0
         self.config = None
         self.last_latent = None

@@ -102,6 +101,8 @@ class VanillaStableDiffusionSampler:

     def initialize(self, p):
         self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+        if self.eta != 0.0:
+            p.extra_generation_params["Eta DDIM"] = self.eta

         for fieldname in ['p_sample_ddim', 'p_sample_plms']:
             if hasattr(self.sampler, fieldname):

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index adb6883e..aa7f106b 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -2,7 +2,7 @@ from collections import deque
 import torch
 import inspect
 import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis
+from modules import prompt_parser, devices, sd_samplers_common

 from modules.shared import opts, state
 import modules.shared as shared
@@ -164,7 +164,6 @@ class KDiffusionSampler:
         self.sampler_noises = None
         self.stop_at = None
         self.eta = None
-        self.default_eta = 1.0
         self.config = None
         self.last_latent = None

@@ -199,7 +198,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap_cfg.step = 0
-        self.eta = p.eta or opts.eta_ancestral
+        self.eta = p.eta if p.eta is not None else opts.eta_ancestral

         k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

@@ -209,6 +208,9 @@ class KDiffusionSampler:
                 extra_params_kwargs[param_name] = getattr(p, param_name)

         if 'eta' in inspect.signature(self.func).parameters:
+            if self.eta != 1.0:
+                p.extra_generation_params["Eta"] = self.eta
+
             extra_params_kwargs['eta'] = self.eta

         return extra_params_kwargs


From ab059b6e4863eaa5e118a2043192584e6df51ed4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 30 Jan 2023 10:52:15 +0300
Subject: [PATCH 143/151] make the program read Discard penultimate sigma from
 generation parameters

---
 modules/generation_parameters_copypaste.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 7ee8ee10..fc9e17aa 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -295,6 +295,7 @@ infotext_to_setting_name_mapping = [
     ('Noise multiplier', 'initial_noise_multiplier'),
     ('Eta', 'eta_ancestral'),
     ('Eta DDIM', 'eta_ddim'),
+    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
 ]


From aa4688eb8345de583070ca9ddb4c6f585f06762b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 30 Jan 2023 13:29:44 +0300
Subject: [PATCH 144/151] disable EMA weights for instructpix2pix model, which
 should get memory usage as well as image quality back to what it was before
 d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2

---
 configs/instruct-pix2pix.yaml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml
index 437ddcef..4e896879 100644
--- a/configs/instruct-pix2pix.yaml
+++ b/configs/instruct-pix2pix.yaml
@@ -20,8 +20,7 @@ model:
     conditioning_key: hybrid
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: true
-    load_ema: true
+    use_ema: false

     scheduler_config: # 10000 warmup steps
       target: ldm.lr_scheduler.LambdaLinearScheduler


From ee9fdf7f62984dc30770fb1a73e68736b319746f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 30 Jan 2023 14:56:28 +0300
Subject: [PATCH 145/151] Add --skip-version-check to disable messages asking
 users to upgrade torch.
---
 modules/shared.py |  2 ++
 webui.py          | 10 +++++++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/modules/shared.py b/modules/shared.py
index 96a2572f..69634fd8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -105,6 +105,8 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
 parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
 parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
+parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
+

 script_loading.preload_extensions(extensions.extensions_dir, parser)

diff --git a/webui.py b/webui.py
index 0d0b8364..5b5c2139 100644
--- a/webui.py
+++ b/webui.py
@@ -52,6 +52,9 @@ else:


 def check_versions():
+    if shared.cmd_opts.skip_version_check:
+        return
+
     expected_torch_version = "1.13.1"

     if version.parse(torch.__version__) < version.parse(expected_torch_version):
@@ -59,7 +62,10 @@ def check_versions():
 You are running torch {torch.__version__}.
 The program is tested to work with torch {expected_torch_version}.
 To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded.
+Beware that this will cause a lot of large files to be downloaded, and there
+are reports of issues with the training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
 """.strip())

     expected_xformers_version = "0.0.16rc425"
@@ -71,6 +77,8 @@ Beware that this will cause a lot of large files to be downloaded.
 You are running xformers {xformers.__version__}.
 The program is tested to work with xformers {expected_xformers_version}.
 To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
""".strip()) From 19de2a626b92bcfe83a97477f20d0faf9b3204c0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 15:48:09 +0300 Subject: [PATCH 146/151] make linux launch.py use XFORMERS_PACKAGE var too; thanks, acncagua --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 25909469..c44c48fa 100644 --- a/launch.py +++ b/launch.py @@ -290,7 +290,7 @@ def prepare_environment(): if not is_installed("xformers"): exit(0) elif platform.system() == "Linux": - run_pip("install xformers==0.0.16rc425", "xformers") + run_pip(f"install xformers=={xformers_package}", "xformers") if not is_installed("pyngrok") and ngrok: run_pip("install pyngrok", "ngrok") From 2c1bb46c7ad5b4536f6587d327a03f0ff7811c5d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 18:48:10 +0300 Subject: [PATCH 147/151] amend the error in previous commit --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index c44c48fa..9fd766d1 100644 --- a/launch.py +++ b/launch.py @@ -290,7 +290,7 @@ def prepare_environment(): if not is_installed("xformers"): exit(0) elif platform.system() == "Linux": - run_pip(f"install xformers=={xformers_package}", "xformers") + run_pip(f"install {xformers_package}", "xformers") if not is_installed("pyngrok") and ngrok: run_pip("install pyngrok", "ngrok") From bfe7e7f15fbceccd016957769cd5b5a26c82a45b Mon Sep 17 00:00:00 2001 From: Piotr Pomierski Date: Tue, 31 Jan 2023 01:51:07 +0100 Subject: [PATCH 148/151] Fix missing tooltip for 'Clear prompt' button --- javascript/hints.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/hints.js b/javascript/hints.js index 7b60b25e..75792d0d 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -17,7 +17,7 @@ titles = { "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", "\u{1f4be}": "Save style", - "\U0001F5D1": "Clear prompt", + "\u{1f5d1}": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", "\u{1f3b4}": "Show extra networks", From 0426b3478937e54446337cf435ed3f548688b120 Mon Sep 17 00:00:00 2001 From: Joey Sanchez Date: Mon, 30 Jan 2023 21:46:13 -0500 Subject: [PATCH 149/151] Adding default true to use_original_name_batch as images should by default hold the same name to help keep sequenced images in their correct order --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 69634fd8..5600d480 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,7 +327,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), - "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), + "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' 
button, only save a single selected image"), "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), From 7738c057ce938ca5c5a53a95e2023d3bcf14f06a Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 1 Feb 2023 05:23:58 -0500 Subject: [PATCH 150/151] MPS fix is still needed :( Apparently I did not test with large enough images to trigger the bug with torch.narrow on MPS --- modules/devices.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/devices.py b/modules/devices.py index 655ca1d3..f4afb897 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -207,3 +207,6 @@ if has_mps(): cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) + From 2217331cd1245d0bdda786a5dcaf4f7b843bd7e4 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 1 Feb 2023 06:20:19 -0500 Subject: [PATCH 151/151] Refactor MPS fixes to CondFunc --- modules/devices.py | 50 +++++++++++++--------------------------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index f4afb897..919048d0 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -2,6 +2,7 @@ import sys, os, shlex import contextlib import torch from modules import errors +from modules.sd_hijack_utils import CondFunc from packaging import version @@ -156,36 +157,7 @@ def test_for_nans(x, where): raise NansException(message) -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -orig_tensor_to = torch.Tensor.to -def tensor_to_fix(self, *args, **kwargs): - if self.device.type != 'mps' and \ - ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \ - (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')): - self = self.contiguous() - return orig_tensor_to(self, *args, **kwargs) - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 -orig_layer_norm = torch.nn.functional.layer_norm -def layer_norm_fix(*args, **kwargs): - if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps': - args = list(args) - args[0] = args[0].contiguous() - return orig_layer_norm(*args, **kwargs) - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 -orig_tensor_numpy = torch.Tensor.numpy -def numpy_fix(self, *args, **kwargs): - if self.requires_grad: - self = self.detach() - return orig_tensor_numpy(self, *args, **kwargs) - - # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 -orig_cumsum = torch.cumsum -orig_Tensor_cumsum = torch.Tensor.cumsum def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': output_dtype = kwargs.get('dtype', input.dtype) @@ -199,14 +171,20 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): if has_mps(): if version.parse(torch.__version__) < version.parse("1.13"): # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix + + # MPS workaround for 
https://github.com/pytorch/pytorch/issues/79383 + CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), + lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), + lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') + # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 + CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) elif version.parse(torch.__version__) > version.parse("1.13.1"): cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) - torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) - torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) - orig_narrow = torch.narrow - torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) + cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) + CondFunc('torch.cumsum', cumsum_fix_func, None) + CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) + CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
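
A note on the guidance arithmetic in this series: CFGDenoiser.combine_denoised, which patch 141 strips from sd_samplers.py and keeps in sd_samplers_kdiffusion.py, blends the unconditional and prompt-conditioned denoiser outputs. A tiny worked example of that arithmetic, reduced to scalars (the numbers are made up for illustration and are not part of any patch; real latents are image-shaped tensors):

import torch

# One positive prompt with weight 1.0 and CFG scale 7.5.
cond_scale = 7.5
e_uncond = torch.tensor([0.2])  # denoiser output for the negative/empty prompt
e_cond = torch.tensor([1.0])    # denoiser output for the positive prompt

# combine_denoised computes, per batch item:
#   denoised = uncond + sum_j (cond_j - uncond) * (weight_j * cond_scale)
denoised = e_uncond + (e_cond - e_uncond) * (1.0 * cond_scale)
print(denoised)  # tensor([6.2000])

With cond_scale = 1.0 this collapses to the conditioned prediction itself; larger scales push the result further away from the unconditional prediction, which is what the UI's CFG scale slider controls.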
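
The final patch above routes every MPS workaround through CondFunc, imported from modules.sd_hijack_utils; that helper's implementation is not included in this series. As a rough sketch only (inferred from the call sites above, not the repository's actual code; cond_patch is a hypothetical name), a conditional monkey-patching helper with the same calling convention could look like this:

import importlib

def cond_patch(func_path, sub_func, cond_func=None):
    # Hypothetical stand-in for CondFunc. Resolve a dotted path such as
    # 'torch.Tensor.to': try ever-shorter prefixes as module imports, then
    # walk the remaining attributes (classes, submodules) to the callable.
    # Assumes at least the root of the path is importable.
    parts = func_path.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            owner = importlib.import_module('.'.join(parts[:i]))
            break
        except ImportError:
            continue
    for name in parts[i:-1]:
        owner = getattr(owner, name)
    orig_func = getattr(owner, parts[-1])

    def wrapper(*args, **kwargs):
        # Route through the substitute (which receives the original callable
        # as its first argument) only when the guard fires, or when no guard
        # was given; otherwise fall back to the original.
        if cond_func is None or cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    setattr(owner, parts[-1], wrapper)
    return wrapper

# Hypothetical usage mirroring the layer_norm fix in the patch: force the
# input contiguous, but only for tensors that live on the MPS device.
import torch
cond_patch(
    'torch.nn.functional.layer_norm',
    lambda orig, *args, **kwargs: orig(args[0].contiguous(), *args[1:], **kwargs),
    lambda _, *args, **kwargs: len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps',
)

Capturing the original callable in the wrapper's closure is what lets each fix stay a one-liner at the call site: the condition and the substitution travel together, instead of the orig_* module-level globals the previous version of devices.py had to maintain.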