Compare commits
No commits in common. "master" and "fix-mean-loss" have entirely different histories.
master
...
fix-mean-l
.github/ISSUE_TEMPLATE
README.mdartists.csvconfigs
extensions-builtin
LDSR
Lora
SwinIR/scripts
prompt-bracket-checker/javascript
roll-artist/scripts
html
javascript
edit-attention.jsextensions.jsextraNetworks.jshints.jshires_fix.jsimageviewer.jslocalization.jsprogressbar.jsui.js
launch.pymodules
api
artists.pycodeformer_model.pydeepbooru_model.pydevices.pyerrors.pyextensions.pyextra_networks.pyextra_networks_hypernet.pyextras.pygeneration_parameters_copypaste.pygfpgan_model.pyhashes.pyhypernetworks
images.pyimg2img.pyinterrogate.pymac_specific.pymodelloader.pymodels/diffusion
paths.pypostprocessing.pyprocessing.pyprogress.pyprompt_parser.pyrealesrgan_model.pyscript_callbacks.pyscript_loading.pyscripts.pyscripts_auto_postprocessing.pyscripts_postprocessing.pysd_disable_initialization.pysd_hijack.pysd_hijack_checkpoint.pysd_hijack_clip.pysd_hijack_inpainting.pysd_hijack_ip2p.pysd_hijack_optimizations.pysd_hijack_unet.pysd_hijack_utils.pysd_models.pysd_models_config.pysd_samplers.pysd_samplers_common.pysd_samplers_compvis.pysd_samplers_kdiffusion.pysd_vae.pyshared.pyshared_items.pystyles.pysub_quadratic_attention.pytextual_inversion
timer.pytxt2img.pyui.pyui_common.pyui_components.pyui_extensions.pyui_extra_networks.pyui_extra_networks_checkpoints.pyui_extra_networks_hypernets.pyui_extra_networks_textual_inversion.pyui_postprocessing.pyupscaler.pyscripts
29
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
29
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
|
@ -37,20 +37,20 @@ body:
|
|||
id: what-should
|
||||
attributes:
|
||||
label: What should have happened?
|
||||
description: Tell what you think the normal behavior should be
|
||||
description: tell what you think the normal behavior should be
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
id: commit
|
||||
attributes:
|
||||
label: Commit where the problem happens
|
||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: platforms
|
||||
attributes:
|
||||
label: What platforms do you use to access the UI ?
|
||||
label: What platforms do you use to access UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Windows
|
||||
|
@ -74,27 +74,10 @@ body:
|
|||
id: cmdargs
|
||||
attributes:
|
||||
label: Command Line Arguments
|
||||
description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
|
||||
description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
|
||||
render: Shell
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: extensions
|
||||
attributes:
|
||||
label: List of extensions
|
||||
description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Console logs
|
||||
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
|
||||
render: Shell
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information
|
||||
description: Please provide us with any relevant additional info or context.
|
||||
label: Additional information, context and logs
|
||||
description: Please provide us with any relevant additional info, context or log output.
|
||||
|
|
19
README.md
19
README.md
|
@ -17,7 +17,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||
- a man in a (tuxedo:1.21) - alternative syntax
|
||||
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
|
||||
- Loopback, run img2img processing multiple times
|
||||
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
|
||||
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
||||
- Textual Inversion
|
||||
- have as many embeddings as you want and use any names you like for them
|
||||
- use multiple embeddings with different numbers of vectors per token
|
||||
|
@ -49,9 +49,9 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||
- Mouseover hints for most UI elements
|
||||
- Possible to change defaults/mix/max/step values for UI elements via text config
|
||||
- Random artist button
|
||||
- Tiling support, a checkbox to create images that can be tiled like textures
|
||||
- Progress bar and live image generation preview
|
||||
- Can use a separate neural network to produce previews with almost none VRAM or compute requirement
|
||||
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
|
||||
- Styles, a way to save part of prompt and easily apply them via dropdown later
|
||||
- Variations, a way to generate same image but with tiny differences
|
||||
|
@ -76,22 +76,13 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||
- hypernetworks and embeddings options
|
||||
- Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
|
||||
- Clip skip
|
||||
- Hypernetworks
|
||||
- Loras (same as Hypernetworks but more pretty)
|
||||
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
|
||||
- Can select to load a different VAE from settings screen
|
||||
- Use Hypernetworks
|
||||
- Use VAEs
|
||||
- Estimated completion time in progress bar
|
||||
- API
|
||||
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
||||
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
||||
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
|
||||
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
|
||||
- Now without any bad letters!
|
||||
- Load checkpoints in safetensors format
|
||||
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
|
||||
- Now with a license!
|
||||
- Reorder elements in the UI from settings screen
|
||||
-
|
||||
|
||||
## Installation and Running
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||
|
@ -155,8 +146,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
|
||||
- xformers - https://github.com/facebookresearch/xformers
|
||||
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
|
||||
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
|
||||
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
||||
- Security advice - RyotaK
|
||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||
- (You)
|
||||
|
|
3041
artists.csv
Normal file
3041
artists.csv
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -1,98 +0,0 @@
|
|||
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
||||
# See more details in LICENSE.
|
||||
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: edited
|
||||
cond_stage_key: edit
|
||||
# image_size: 64
|
||||
# image_size: 32
|
||||
image_size: 16
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: false
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 0 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 128
|
||||
num_workers: 1
|
||||
wrap: false
|
||||
validation:
|
||||
target: edit_dataset.EditDataset
|
||||
params:
|
||||
path: data/clip-filtered-dataset
|
||||
cache_dir: data/
|
||||
cache_name: data_10k
|
||||
split: val
|
||||
min_text_sim: 0.2
|
||||
min_image_sim: 0.75
|
||||
min_direction_sim: 0.2
|
||||
max_samples_per_prompt: 1
|
||||
min_resize_res: 512
|
||||
max_resize_res: 512
|
||||
crop_res: 512
|
||||
output_as_edit: False
|
||||
real_input: True
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import gc
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -14,6 +15,8 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
|||
from ldm.util import instantiate_from_config, ismap
|
||||
from modules import shared, sd_hijack
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
cached_ldsr_model: torch.nn.Module = None
|
||||
|
||||
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
from modules import extra_networks, shared
|
||||
import lora
|
||||
|
||||
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('lora')
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_lora
|
||||
|
||||
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
multipliers = []
|
||||
for params in params_list:
|
||||
assert len(params.items) > 0
|
||||
|
||||
names.append(params.items[0])
|
||||
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||
|
||||
lora.load_loras(names, multipliers)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
|
@ -1,207 +0,0 @@
|
|||
import glob
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
|
||||
from modules import shared, devices, sd_models
|
||||
|
||||
re_digits = re.compile(r"\d+")
|
||||
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
||||
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
|
||||
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
|
||||
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
||||
|
||||
|
||||
def convert_diffusers_name_to_compvis(key):
|
||||
def match(match_list, regex):
|
||||
r = re.match(regex, key)
|
||||
if not r:
|
||||
return False
|
||||
|
||||
match_list.clear()
|
||||
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
||||
return True
|
||||
|
||||
m = []
|
||||
|
||||
if match(m, re_unet_down_blocks):
|
||||
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
|
||||
|
||||
if match(m, re_unet_mid_blocks):
|
||||
return f"diffusion_model_middle_block_1_{m[1]}"
|
||||
|
||||
if match(m, re_unet_up_blocks):
|
||||
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
||||
|
||||
if match(m, re_text_block):
|
||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||
|
||||
return key
|
||||
|
||||
|
||||
class LoraOnDisk:
|
||||
def __init__(self, name, filename):
|
||||
self.name = name
|
||||
self.filename = filename
|
||||
|
||||
|
||||
class LoraModule:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.multiplier = 1.0
|
||||
self.modules = {}
|
||||
self.mtime = None
|
||||
|
||||
|
||||
class LoraUpDownModule:
|
||||
def __init__(self):
|
||||
self.up = None
|
||||
self.down = None
|
||||
self.alpha = None
|
||||
|
||||
|
||||
def assign_lora_names_to_compvis_modules(sd_model):
|
||||
lora_layer_mapping = {}
|
||||
|
||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
|
||||
for name, module in shared.sd_model.model.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
|
||||
sd_model.lora_layer_mapping = lora_layer_mapping
|
||||
|
||||
|
||||
def load_lora(name, filename):
|
||||
lora = LoraModule(name)
|
||||
lora.mtime = os.path.getmtime(filename)
|
||||
|
||||
sd = sd_models.read_state_dict(filename)
|
||||
|
||||
keys_failed_to_match = []
|
||||
|
||||
for key_diffusers, weight in sd.items():
|
||||
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
|
||||
key, lora_key = fullkey.split(".", 1)
|
||||
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
if sd_module is None:
|
||||
keys_failed_to_match.append(key_diffusers)
|
||||
continue
|
||||
|
||||
lora_module = lora.modules.get(key, None)
|
||||
if lora_module is None:
|
||||
lora_module = LoraUpDownModule()
|
||||
lora.modules[key] = lora_module
|
||||
|
||||
if lora_key == "alpha":
|
||||
lora_module.alpha = weight.item()
|
||||
continue
|
||||
|
||||
if type(sd_module) == torch.nn.Linear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.Conv2d:
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||
else:
|
||||
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
||||
|
||||
with torch.no_grad():
|
||||
module.weight.copy_(weight)
|
||||
|
||||
module.to(device=devices.device, dtype=devices.dtype)
|
||||
|
||||
if lora_key == "lora_up.weight":
|
||||
lora_module.up = module
|
||||
elif lora_key == "lora_down.weight":
|
||||
lora_module.down = module
|
||||
else:
|
||||
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
||||
|
||||
if len(keys_failed_to_match) > 0:
|
||||
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
||||
|
||||
return lora
|
||||
|
||||
|
||||
def load_loras(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
|
||||
for lora in loaded_loras:
|
||||
if lora.name in names:
|
||||
already_loaded[lora.name] = lora
|
||||
|
||||
loaded_loras.clear()
|
||||
|
||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
||||
if any([x is None for x in loras_on_disk]):
|
||||
list_available_loras()
|
||||
|
||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
||||
|
||||
for i, name in enumerate(names):
|
||||
lora = already_loaded.get(name, None)
|
||||
|
||||
lora_on_disk = loras_on_disk[i]
|
||||
if lora_on_disk is not None:
|
||||
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
||||
lora = load_lora(name, lora_on_disk.filename)
|
||||
|
||||
if lora is None:
|
||||
print(f"Couldn't find Lora with name {name}")
|
||||
continue
|
||||
|
||||
lora.multiplier = multipliers[i] if multipliers else 1.0
|
||||
loaded_loras.append(lora)
|
||||
|
||||
|
||||
def lora_forward(module, input, res):
|
||||
if len(loaded_loras) == 0:
|
||||
return res
|
||||
|
||||
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
||||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is not None:
|
||||
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
|
||||
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
else:
|
||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def lora_Linear_forward(self, input):
|
||||
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
|
||||
|
||||
|
||||
def lora_Conv2d_forward(self, input):
|
||||
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
||||
|
||||
|
||||
def list_available_loras():
|
||||
available_loras.clear()
|
||||
|
||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||
|
||||
candidates = \
|
||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
|
||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
|
||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
|
||||
|
||||
for filename in sorted(candidates):
|
||||
if os.path.isdir(filename):
|
||||
continue
|
||||
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
|
||||
available_loras[name] = LoraOnDisk(name, filename)
|
||||
|
||||
|
||||
available_loras = {}
|
||||
loaded_loras = []
|
||||
|
||||
list_available_loras()
|
|
@ -1,6 +0,0 @@
|
|||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
|
|
@ -1,38 +0,0 @@
|
|||
import torch
|
||||
import gradio as gr
|
||||
|
||||
import lora
|
||||
import extra_networks_lora
|
||||
import ui_extra_networks_lora
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||
|
||||
|
||||
def unload():
|
||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||
|
||||
|
||||
def before_ui():
|
||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||
extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
|
||||
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
||||
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
||||
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
||||
|
||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||
|
||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||
script_callbacks.on_script_unloaded(unload)
|
||||
script_callbacks.on_before_ui(before_ui)
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
||||
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
|
||||
|
||||
}))
|
|
@ -1,37 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import lora
|
||||
|
||||
from modules import shared, ui_extra_networks
|
||||
|
||||
|
||||
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Lora')
|
||||
|
||||
def refresh(self):
|
||||
lora.list_available_loras()
|
||||
|
||||
def list_items(self):
|
||||
for name, lora_on_disk in lora.available_loras.items():
|
||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||
previews = [path + ".png", path + ".preview.png"]
|
||||
|
||||
preview = None
|
||||
for file in previews:
|
||||
if os.path.isfile(file):
|
||||
preview = self.link_preview(file)
|
||||
break
|
||||
|
||||
yield {
|
||||
"name": name,
|
||||
"filename": path,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||
"local_preview": path + ".png",
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [shared.cmd_opts.lora_dir]
|
||||
|
|
@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
|
|||
from tqdm import tqdm
|
||||
|
||||
from modules import modelloader, devices, script_callbacks, shared
|
||||
from modules.shared import cmd_opts, opts, state
|
||||
from modules.shared import cmd_opts, opts
|
||||
from swinir_model_arch import SwinIR as net
|
||||
from swinir_model_arch_v2 import Swin2SR as net2
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
@ -145,13 +145,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
|||
|
||||
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||
for h_idx in h_idx_list:
|
||||
if state.interrupted or state.skipped:
|
||||
break
|
||||
|
||||
for w_idx in w_idx_list:
|
||||
if state.interrupted or state.skipped:
|
||||
break
|
||||
|
||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||
out_patch = model(in_patch)
|
||||
out_patch_mask = torch.ones_like(out_patch)
|
||||
|
|
|
@ -4,10 +4,16 @@
|
|||
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||
|
||||
function checkBrackets(evt, textArea, counterElt) {
|
||||
errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
|
||||
errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
|
||||
errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
|
||||
function checkBrackets(evt) {
|
||||
textArea = evt.target;
|
||||
tabName = evt.target.parentElement.parentElement.id.split("_")[0];
|
||||
counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
|
||||
|
||||
promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
|
||||
|
||||
errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
|
||||
errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
|
||||
errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
|
||||
|
||||
openBracketRegExp = /\(/g;
|
||||
closeBracketRegExp = /\)/g;
|
||||
|
@ -80,31 +86,22 @@ function checkBrackets(evt, textArea, counterElt) {
|
|||
}
|
||||
|
||||
if(counterElt.title != '') {
|
||||
counterElt.classList.add('error');
|
||||
counterElt.style = 'color: #FF5555;';
|
||||
} else {
|
||||
counterElt.classList.remove('error');
|
||||
counterElt.style = '';
|
||||
}
|
||||
}
|
||||
|
||||
function setupBracketChecking(id_prompt, id_counter){
|
||||
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
||||
var counter = gradioApp().getElementById(id_counter)
|
||||
textarea.addEventListener("input", function(evt){
|
||||
checkBrackets(evt, textarea, counter)
|
||||
});
|
||||
}
|
||||
|
||||
var shadowRootLoaded = setInterval(function() {
|
||||
var shadowRoot = document.querySelector('gradio-app').shadowRoot;
|
||||
if(! shadowRoot) return false;
|
||||
var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) return false;
|
||||
clearInterval(shadowRootLoaded);
|
||||
|
||||
clearInterval(shadowRootLoaded);
|
||||
|
||||
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
|
||||
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
|
||||
setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
|
||||
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
|
||||
}, 1000);
|
||||
|
|
50
extensions-builtin/roll-artist/scripts/roll-artist.py
Normal file
50
extensions-builtin/roll-artist/scripts/roll-artist.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import random
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
import gradio as gr
|
||||
|
||||
art_symbol = '\U0001f3a8' # 🎨
|
||||
global_prompt = None
|
||||
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
|
||||
|
||||
|
||||
def roll_artist(prompt):
|
||||
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
|
||||
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
||||
|
||||
return prompt + ", " + artist.name if prompt != '' else artist.name
|
||||
|
||||
|
||||
def add_roll_button(prompt):
|
||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
inputs=[
|
||||
prompt,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def after_component(component, **kwargs):
|
||||
global global_prompt
|
||||
|
||||
elem_id = kwargs.get('elem_id', None)
|
||||
if elem_id not in related_ids:
|
||||
return
|
||||
|
||||
if elem_id == "txt2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "txt2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
elif elem_id == "img2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "img2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
|
||||
|
||||
script_callbacks.on_after_component(after_component)
|
Binary file not shown.
Before ![]() (image error) Size: 82 KiB |
|
@ -1,12 +0,0 @@
|
|||
<div class='card' {preview_html} onclick={card_clicked}>
|
||||
<div class='actions'>
|
||||
<div class='additional'>
|
||||
<ul>
|
||||
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
||||
</ul>
|
||||
<span style="display:none" class='search_term'>{search_term}</span>
|
||||
</div>
|
||||
<span class='name'>{name}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
<div class='nocards'>
|
||||
<h1>Nothing here. Add some content to the following directories:</h1>
|
||||
|
||||
<ul>
|
||||
{dirs}
|
||||
</ul>
|
||||
</div>
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<filter id='shadow' color-interpolation-filters="sRGB">
|
||||
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
|
||||
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
|
||||
</filter>
|
||||
<path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" />
|
||||
</svg>
|
Before (image error) Size: 989 B |
|
@ -1,96 +1,75 @@
|
|||
function keyupEditAttention(event){
|
||||
addEventListener('keydown', (event) => {
|
||||
let target = event.originalTarget || event.composedPath()[0];
|
||||
if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
|
||||
if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
|
||||
if (! (event.metaKey || event.ctrlKey)) return;
|
||||
|
||||
let isPlus = event.key == "ArrowUp"
|
||||
let isMinus = event.key == "ArrowDown"
|
||||
if (!isPlus && !isMinus) return;
|
||||
|
||||
let plus = "ArrowUp"
|
||||
let minus = "ArrowDown"
|
||||
if (event.key != plus && event.key != minus) return;
|
||||
|
||||
let selectionStart = target.selectionStart;
|
||||
let selectionEnd = target.selectionEnd;
|
||||
let text = target.value;
|
||||
|
||||
function selectCurrentParenthesisBlock(OPEN, CLOSE){
|
||||
if (selectionStart !== selectionEnd) return false;
|
||||
|
||||
// If the user hasn't selected anything, let's select their current parenthesis block
|
||||
if (selectionStart === selectionEnd) {
|
||||
// Find opening parenthesis around current cursor
|
||||
const before = text.substring(0, selectionStart);
|
||||
let beforeParen = before.lastIndexOf(OPEN);
|
||||
if (beforeParen == -1) return false;
|
||||
let beforeParenClose = before.lastIndexOf(CLOSE);
|
||||
const before = target.value.substring(0, selectionStart);
|
||||
let beforeParen = before.lastIndexOf("(");
|
||||
if (beforeParen == -1) return;
|
||||
let beforeParenClose = before.lastIndexOf(")");
|
||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
||||
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
|
||||
beforeParen = before.lastIndexOf("(", beforeParen - 1);
|
||||
beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
|
||||
}
|
||||
|
||||
// Find closing parenthesis around current cursor
|
||||
const after = text.substring(selectionStart);
|
||||
let afterParen = after.indexOf(CLOSE);
|
||||
if (afterParen == -1) return false;
|
||||
let afterParenOpen = after.indexOf(OPEN);
|
||||
const after = target.value.substring(selectionStart);
|
||||
let afterParen = after.indexOf(")");
|
||||
if (afterParen == -1) return;
|
||||
let afterParenOpen = after.indexOf("(");
|
||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
||||
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
|
||||
afterParen = after.indexOf(")", afterParen + 1);
|
||||
afterParenOpen = after.indexOf("(", afterParenOpen + 1);
|
||||
}
|
||||
if (beforeParen === -1 || afterParen === -1) return false;
|
||||
if (beforeParen === -1 || afterParen === -1) return;
|
||||
|
||||
// Set the selection to the text between the parenthesis
|
||||
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
||||
const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
|
||||
const lastColon = parenContent.lastIndexOf(":");
|
||||
selectionStart = beforeParen + 1;
|
||||
selectionEnd = selectionStart + lastColon;
|
||||
target.setSelectionRange(selectionStart, selectionEnd);
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the user hasn't selected anything, let's select their current parenthesis block
|
||||
if(! selectCurrentParenthesisBlock('<', '>')){
|
||||
selectCurrentParenthesisBlock('(', ')')
|
||||
}
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
closeCharacter = ')'
|
||||
delta = opts.keyedit_precision_attention
|
||||
if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
|
||||
target.value = target.value.slice(0, selectionStart) +
|
||||
"(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
|
||||
target.value.slice(selectionEnd);
|
||||
|
||||
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
||||
closeCharacter = '>'
|
||||
delta = opts.keyedit_precision_extra
|
||||
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
|
||||
target.focus();
|
||||
target.selectionStart = selectionStart + 1;
|
||||
target.selectionEnd = selectionEnd + 1;
|
||||
|
||||
// do not include spaces at the end
|
||||
while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
|
||||
selectionEnd -= 1;
|
||||
}
|
||||
if(selectionStart == selectionEnd){
|
||||
return
|
||||
}
|
||||
} else {
|
||||
end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
|
||||
weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||
if (isNaN(weight)) return;
|
||||
if (event.key == minus) weight -= 0.1;
|
||||
if (event.key == plus) weight += 0.1;
|
||||
|
||||
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
||||
weight = parseFloat(weight.toPrecision(12));
|
||||
|
||||
selectionStart += 1;
|
||||
selectionEnd += 1;
|
||||
}
|
||||
target.value = target.value.slice(0, selectionEnd + 1) +
|
||||
weight +
|
||||
target.value.slice(selectionEnd + 1 + end - 1);
|
||||
|
||||
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||
if (isNaN(weight)) return;
|
||||
|
||||
weight += isPlus ? delta : -delta;
|
||||
weight = parseFloat(weight.toPrecision(12));
|
||||
if(String(weight).length == 1) weight += ".0"
|
||||
|
||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
||||
|
||||
target.focus();
|
||||
target.value = text;
|
||||
target.selectionStart = selectionStart;
|
||||
target.selectionEnd = selectionEnd;
|
||||
|
||||
updateInput(target)
|
||||
}
|
||||
|
||||
addEventListener('keydown', (event) => {
|
||||
keyupEditAttention(event);
|
||||
});
|
||||
target.focus();
|
||||
target.selectionStart = selectionStart;
|
||||
target.selectionEnd = selectionEnd;
|
||||
}
|
||||
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
|
||||
// internal Svelte data binding remains in sync.
|
||||
target.dispatchEvent(new Event("input", { bubbles: true }));
|
||||
});
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
|
||||
function extensions_apply(_, _){
|
||||
var disable = []
|
||||
var update = []
|
||||
|
||||
disable = []
|
||||
update = []
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
|
@ -17,24 +16,11 @@ function extensions_apply(_, _){
|
|||
}
|
||||
|
||||
function extensions_check(){
|
||||
var disable = []
|
||||
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
})
|
||||
|
||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||
x.innerHTML = "Loading..."
|
||||
})
|
||||
|
||||
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
|
||||
|
||||
})
|
||||
|
||||
return [id, JSON.stringify(disable)]
|
||||
return []
|
||||
}
|
||||
|
||||
function install_extension_from_index(button, url){
|
||||
|
@ -43,7 +29,7 @@ function install_extension_from_index(button, url){
|
|||
|
||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||
textarea.value = url
|
||||
updateInput(textarea)
|
||||
textarea.dispatchEvent(new Event("input", { bubbles: true }))
|
||||
|
||||
gradioApp().querySelector('#install_extension_button').click()
|
||||
}
|
||||
|
|
|
@ -1,107 +0,0 @@
|
|||
|
||||
function setupExtraNetworksForTab(tabname){
|
||||
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
||||
|
||||
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
|
||||
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
|
||||
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
|
||||
var close = gradioApp().getElementById(tabname+'_extra_close')
|
||||
|
||||
search.classList.add('search')
|
||||
tabs.appendChild(search)
|
||||
tabs.appendChild(refresh)
|
||||
tabs.appendChild(close)
|
||||
|
||||
search.addEventListener("input", function(evt){
|
||||
searchTerm = search.value.toLowerCase()
|
||||
|
||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
||||
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
||||
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
var activePromptTextarea = {};
|
||||
|
||||
function setupExtraNetworks(){
|
||||
setupExtraNetworksForTab('txt2img')
|
||||
setupExtraNetworksForTab('img2img')
|
||||
|
||||
function registerPrompt(tabname, id){
|
||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||
|
||||
if (! activePromptTextarea[tabname]){
|
||||
activePromptTextarea[tabname] = textarea
|
||||
}
|
||||
|
||||
textarea.addEventListener("focus", function(){
|
||||
activePromptTextarea[tabname] = textarea;
|
||||
});
|
||||
}
|
||||
|
||||
registerPrompt('txt2img', 'txt2img_prompt')
|
||||
registerPrompt('txt2img', 'txt2img_neg_prompt')
|
||||
registerPrompt('img2img', 'img2img_prompt')
|
||||
registerPrompt('img2img', 'img2img_neg_prompt')
|
||||
}
|
||||
|
||||
onUiLoaded(setupExtraNetworks)
|
||||
|
||||
var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
|
||||
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
|
||||
|
||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
||||
var m = text.match(re_extranet)
|
||||
if(! m) return false
|
||||
|
||||
var partToSearch = m[1]
|
||||
var replaced = false
|
||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
|
||||
m = found.match(re_extranet);
|
||||
if(m[1] == partToSearch){
|
||||
replaced = true;
|
||||
return ""
|
||||
}
|
||||
return found;
|
||||
})
|
||||
|
||||
if(replaced){
|
||||
textarea.value = newTextareaText
|
||||
return true;
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
function cardClicked(tabname, textToAdd, allowNegativePrompt){
|
||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
||||
|
||||
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
|
||||
textarea.value = textarea.value + " " + textToAdd
|
||||
}
|
||||
|
||||
updateInput(textarea)
|
||||
}
|
||||
|
||||
function saveCardPreview(event, tabname, filename){
|
||||
var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
|
||||
var button = gradioApp().getElementById(tabname + '_save_preview')
|
||||
|
||||
textarea.value = filename
|
||||
updateInput(textarea)
|
||||
|
||||
button.click()
|
||||
|
||||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
}
|
||||
|
||||
function extraNetworksSearchButton(tabs_id, event){
|
||||
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
||||
button = event.target
|
||||
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
||||
|
||||
searchTextarea.value = text
|
||||
updateInput(searchTextarea)
|
||||
}
|
|
@ -14,14 +14,12 @@ titles = {
|
|||
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
||||
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
||||
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
"\u{1f5d1}": "Clear prompt",
|
||||
"\U0001F5D1": "Clear prompt",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
"\u{1f4d2}": "Paste available values into the field",
|
||||
"\u{1f3b4}": "Show extra networks",
|
||||
|
||||
|
||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||
|
@ -50,7 +48,7 @@ titles = {
|
|||
|
||||
"None": "Do not do anything special",
|
||||
"Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
|
||||
"X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
|
||||
"X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
|
||||
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
|
||||
|
||||
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
|
||||
|
@ -66,8 +64,8 @@ titles = {
|
|||
|
||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||
|
||||
"Loopback": "Process an image, use it as an input, repeat.",
|
||||
|
@ -93,7 +91,6 @@ titles = {
|
|||
|
||||
"Weighted sum": "Result = A * (1 - M) + B * M",
|
||||
"Add difference": "Result = A + (B - C) * M",
|
||||
"No interpolation": "Result = A",
|
||||
|
||||
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
||||
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||
|
@ -107,10 +104,7 @@ titles = {
|
|||
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
|
||||
"Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
|
||||
"Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
|
||||
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
||||
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
||||
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
||||
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
|
||||
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders."
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
function setInactive(elem, inactive){
|
||||
console.log(elem)
|
||||
if(inactive){
|
||||
elem.classList.add('inactive')
|
||||
} else{
|
||||
|
@ -8,6 +9,8 @@ function setInactive(elem, inactive){
|
|||
}
|
||||
|
||||
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
||||
console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y)
|
||||
|
||||
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
||||
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
||||
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
||||
|
|
|
@ -148,15 +148,7 @@ function showGalleryImage() {
|
|||
if(e && e.parentElement.tagName == 'DIV'){
|
||||
e.style.cursor='pointer'
|
||||
e.style.userSelect='none'
|
||||
|
||||
var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
|
||||
|
||||
// For Firefox, listening on click first switched to next image then shows the lightbox.
|
||||
// If you know how to fix this without switching to mousedown event, please.
|
||||
// For other browsers the event is click to make it possiblr to drag picture.
|
||||
var event = isFirefox ? 'mousedown' : 'click'
|
||||
|
||||
e.addEventListener(event, function (evt) {
|
||||
e.addEventListener('mousedown', function (evt) {
|
||||
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||
evt.preventDefault()
|
||||
|
|
|
@ -10,8 +10,10 @@ ignore_ids_for_localization={
|
|||
modelmerger_tertiary_model_name: 'OPTION',
|
||||
train_embedding: 'OPTION',
|
||||
train_hypernetwork: 'OPTION',
|
||||
txt2img_styles: 'OPTION',
|
||||
img2img_styles: 'OPTION',
|
||||
txt2img_style_index: 'OPTION',
|
||||
txt2img_style2_index: 'OPTION',
|
||||
img2img_style_index: 'OPTION',
|
||||
img2img_style2_index: 'OPTION',
|
||||
setting_random_artist_categories: 'SPAN',
|
||||
setting_face_restoration_model: 'SPAN',
|
||||
setting_realesrgan_enabled_models: 'SPAN',
|
||||
|
|
|
@ -81,13 +81,8 @@ function request(url, data, handler, errorHandler){
|
|||
xhr.onreadystatechange = function () {
|
||||
if (xhr.readyState === 4) {
|
||||
if (xhr.status === 200) {
|
||||
try {
|
||||
var js = JSON.parse(xhr.responseText);
|
||||
handler(js)
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
errorHandler()
|
||||
}
|
||||
var js = JSON.parse(xhr.responseText);
|
||||
handler(js)
|
||||
} else{
|
||||
errorHandler()
|
||||
}
|
||||
|
@ -111,19 +106,6 @@ function formatTime(secs){
|
|||
}
|
||||
}
|
||||
|
||||
function setTitle(progress){
|
||||
var title = 'Stable Diffusion'
|
||||
|
||||
if(opts.show_progress_in_title && progress){
|
||||
title = '[' + progress.trim() + '] ' + title;
|
||||
}
|
||||
|
||||
if(document.title != title){
|
||||
document.title = title;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function randomId(){
|
||||
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
|
||||
}
|
||||
|
@ -135,32 +117,30 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||
var dateStart = new Date()
|
||||
var wasEverActive = false
|
||||
var parentProgressbar = progressbarContainer.parentNode
|
||||
var parentGallery = gallery ? gallery.parentNode : null
|
||||
var parentGallery = gallery.parentNode
|
||||
|
||||
var divProgress = document.createElement('div')
|
||||
divProgress.className='progressDiv'
|
||||
divProgress.style.display = opts.show_progressbar ? "" : "none"
|
||||
var divInner = document.createElement('div')
|
||||
divInner.className='progress'
|
||||
|
||||
divProgress.appendChild(divInner)
|
||||
parentProgressbar.insertBefore(divProgress, progressbarContainer)
|
||||
|
||||
if(parentGallery){
|
||||
var livePreview = document.createElement('div')
|
||||
livePreview.className='livePreview'
|
||||
parentGallery.insertBefore(livePreview, gallery)
|
||||
}
|
||||
var livePreview = document.createElement('div')
|
||||
livePreview.className='livePreview'
|
||||
parentGallery.insertBefore(livePreview, gallery)
|
||||
|
||||
var removeProgressBar = function(){
|
||||
setTitle("")
|
||||
parentProgressbar.removeChild(divProgress)
|
||||
if(parentGallery) parentGallery.removeChild(livePreview)
|
||||
parentGallery.removeChild(livePreview)
|
||||
atEnd()
|
||||
}
|
||||
|
||||
var fun = function(id_task, id_live_preview){
|
||||
request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
|
||||
request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
|
||||
console.log(res)
|
||||
|
||||
if(res.completed){
|
||||
removeProgressBar()
|
||||
return
|
||||
|
@ -175,7 +155,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||
progressText = ""
|
||||
|
||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
||||
divInner.style.background = res.progress ? "" : "transparent"
|
||||
|
||||
if(res.progress > 0){
|
||||
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
|
||||
|
@ -183,13 +162,8 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||
|
||||
if(res.eta){
|
||||
progressText += " ETA: " + formatTime(res.eta)
|
||||
}
|
||||
|
||||
|
||||
setTitle(progressText)
|
||||
|
||||
if(res.textinfo && res.textinfo.indexOf("\n") == -1){
|
||||
progressText = res.textinfo + " " + progressText
|
||||
} else if(res.textinfo){
|
||||
progressText += " " + res.textinfo
|
||||
}
|
||||
|
||||
divInner.textContent = progressText
|
||||
|
@ -209,15 +183,16 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||
}
|
||||
|
||||
|
||||
if(res.live_preview && gallery){
|
||||
var rect = gallery.getBoundingClientRect()
|
||||
if(rect.width){
|
||||
livePreview.style.width = rect.width + "px"
|
||||
livePreview.style.height = rect.height + "px"
|
||||
}
|
||||
|
||||
if(res.live_preview){
|
||||
var img = new Image();
|
||||
img.onload = function() {
|
||||
var rect = gallery.getBoundingClientRect()
|
||||
if(rect.width){
|
||||
livePreview.style.width = rect.width + "px"
|
||||
livePreview.style.height = rect.height + "px"
|
||||
}
|
||||
|
||||
livePreview.innerHTML = ''
|
||||
livePreview.appendChild(img)
|
||||
if(livePreview.childElementCount > 2){
|
||||
livePreview.removeChild(livePreview.firstElementChild)
|
||||
|
@ -233,7 +208,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||
|
||||
setTimeout(() => {
|
||||
fun(id_task, res.id_live_preview);
|
||||
}, opts.live_preview_refresh_period || 500)
|
||||
}, 500)
|
||||
}, function(){
|
||||
removeProgressBar()
|
||||
})
|
||||
|
|
|
@ -104,11 +104,9 @@ function create_tab_index_args(tabId, args){
|
|||
return res
|
||||
}
|
||||
|
||||
function get_img2img_tab_index() {
|
||||
let res = args_to_array(arguments)
|
||||
res.splice(-2)
|
||||
res[0] = get_tab_index('mode_img2img')
|
||||
return res
|
||||
function get_extras_tab_index(){
|
||||
const [,,...args] = [...arguments]
|
||||
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
|
||||
}
|
||||
|
||||
function create_submit_args(args){
|
||||
|
@ -167,15 +165,6 @@ function submit_img2img(){
|
|||
return res
|
||||
}
|
||||
|
||||
function modelmerger(){
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
||||
|
||||
var res = create_submit_args(arguments)
|
||||
res[0] = id
|
||||
return res
|
||||
}
|
||||
|
||||
|
||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||
name_ = prompt('Style name:')
|
||||
|
@ -192,26 +181,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
|||
}
|
||||
|
||||
|
||||
promptTokecountUpdateFuncs = {}
|
||||
|
||||
function recalculatePromptTokens(name){
|
||||
if(promptTokecountUpdateFuncs[name]){
|
||||
promptTokecountUpdateFuncs[name]()
|
||||
}
|
||||
}
|
||||
|
||||
function recalculate_prompts_txt2img(){
|
||||
recalculatePromptTokens('txt2img_prompt')
|
||||
recalculatePromptTokens('txt2img_neg_prompt')
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function recalculate_prompts_img2img(){
|
||||
recalculatePromptTokens('img2img_prompt')
|
||||
recalculatePromptTokens('img2img_neg_prompt')
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
|
||||
opts = {}
|
||||
onUiUpdate(function(){
|
||||
|
@ -220,8 +189,8 @@ onUiUpdate(function(){
|
|||
json_elem = gradioApp().getElementById('settings_json')
|
||||
if(json_elem == null) return;
|
||||
|
||||
var textarea = json_elem.querySelector('textarea')
|
||||
var jsdata = textarea.value
|
||||
textarea = json_elem.querySelector('textarea')
|
||||
jsdata = textarea.value
|
||||
opts = JSON.parse(jsdata)
|
||||
executeCallbacks(optionsChangedCallbacks);
|
||||
|
||||
|
@ -245,27 +214,14 @@ onUiUpdate(function(){
|
|||
|
||||
json_elem.parentElement.style.display="none"
|
||||
|
||||
function registerTextarea(id, id_counter, id_button){
|
||||
var prompt = gradioApp().getElementById(id)
|
||||
var counter = gradioApp().getElementById(id_counter)
|
||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||
|
||||
if(counter.parentElement == prompt.parentElement){
|
||||
return
|
||||
}
|
||||
|
||||
prompt.parentElement.insertBefore(counter, prompt)
|
||||
counter.classList.add("token-counter")
|
||||
prompt.parentElement.style.position = "relative"
|
||||
|
||||
promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
|
||||
textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
|
||||
}
|
||||
|
||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
||||
registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
|
||||
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
||||
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
||||
if (!txt2img_textarea) {
|
||||
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||
}
|
||||
if (!img2img_textarea) {
|
||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||
}
|
||||
|
||||
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
||||
settings_tabs = gradioApp().querySelector('#settings div')
|
||||
|
@ -279,6 +235,7 @@ onUiUpdate(function(){
|
|||
}
|
||||
})
|
||||
|
||||
|
||||
onOptionsChanged(function(){
|
||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||
|
@ -293,7 +250,7 @@ onOptionsChanged(function(){
|
|||
|
||||
let txt2img_textarea, img2img_textarea = undefined;
|
||||
let wait_time = 800
|
||||
let token_timeouts = {};
|
||||
let token_timeout;
|
||||
|
||||
function update_txt2img_tokens(...args) {
|
||||
update_token_counter("txt2img_token_button")
|
||||
|
@ -310,9 +267,9 @@ function update_img2img_tokens(...args) {
|
|||
}
|
||||
|
||||
function update_token_counter(button_id) {
|
||||
if (token_timeouts[button_id])
|
||||
clearTimeout(token_timeouts[button_id]);
|
||||
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||
if (token_timeout)
|
||||
clearTimeout(token_timeout);
|
||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||
}
|
||||
|
||||
function restart_reload(){
|
||||
|
@ -321,18 +278,3 @@ function restart_reload(){
|
|||
|
||||
return []
|
||||
}
|
||||
|
||||
// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
|
||||
// will only visible on web page and not sent to python.
|
||||
function updateInput(target){
|
||||
let e = new Event("input", { bubbles: true })
|
||||
Object.defineProperty(e, "target", {value: target})
|
||||
target.dispatchEvent(e);
|
||||
}
|
||||
|
||||
|
||||
var desiredCheckpointName = null;
|
||||
function selectCheckpoint(name){
|
||||
desiredCheckpointName = name;
|
||||
gradioApp().getElementById('change_checkpoint').click()
|
||||
}
|
||||
|
|
79
launch.py
79
launch.py
|
@ -14,38 +14,6 @@ python = sys.executable
|
|||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
stored_commit_hash = None
|
||||
skip_install = False
|
||||
|
||||
|
||||
def check_python_version():
|
||||
is_windows = platform.system() == "Windows"
|
||||
major = sys.version_info.major
|
||||
minor = sys.version_info.minor
|
||||
micro = sys.version_info.micro
|
||||
|
||||
if is_windows:
|
||||
supported_minors = [10]
|
||||
else:
|
||||
supported_minors = [7, 8, 9, 10, 11]
|
||||
|
||||
if not (major == 3 and minor in supported_minors):
|
||||
import modules.errors
|
||||
|
||||
modules.errors.print_error_explanation(f"""
|
||||
INCOMPATIBLE PYTHON VERSION
|
||||
|
||||
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
||||
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
||||
or any other error regarding unsuccessful package (library) installation,
|
||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||
and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
|
||||
|
||||
def commit_hash():
|
||||
|
@ -79,19 +47,10 @@ def extract_opt(args, name):
|
|||
return args, is_present, opt
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
||||
def run(command, desc=None, errdesc=None, custom_env=None):
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
|
||||
if live:
|
||||
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
||||
Command: {command}
|
||||
Error code: {result.returncode}""")
|
||||
|
||||
return ""
|
||||
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
||||
|
||||
if result.returncode != 0:
|
||||
|
@ -130,9 +89,6 @@ def run_python(code, desc=None, errdesc=None):
|
|||
|
||||
|
||||
def run_pip(args, desc=None):
|
||||
if skip_install:
|
||||
return
|
||||
|
||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
||||
|
||||
|
@ -148,18 +104,18 @@ def git_clone(url, dir, name, commithash=None):
|
|||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
||||
current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||
return
|
||||
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
||||
|
||||
if commithash is not None:
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
|
||||
|
||||
def version_check(commit):
|
||||
|
@ -217,17 +173,16 @@ def run_extensions_installers(settings_file):
|
|||
|
||||
|
||||
def prepare_environment():
|
||||
global skip_install
|
||||
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
|
||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||
|
||||
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
||||
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
|
@ -248,25 +203,19 @@ def prepare_environment():
|
|||
|
||||
sys.argv, _ = extract_arg(sys.argv, '-f')
|
||||
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
|
||||
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
|
||||
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
||||
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
|
||||
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
|
||||
xformers = '--xformers' in sys.argv
|
||||
ngrok = '--ngrok' in sys.argv
|
||||
|
||||
if not skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
commit = commit_hash()
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
|
||||
if not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
|
||||
|
||||
if not skip_torch_cuda_test:
|
||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
||||
|
@ -283,14 +232,14 @@ def prepare_environment():
|
|||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||
if platform.system() == "Windows":
|
||||
if platform.python_version().startswith("3.10"):
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
|
||||
else:
|
||||
print("Installation of xformers is not supported in this version of Python.")
|
||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||
if not is_installed("xformers"):
|
||||
exit(0)
|
||||
elif platform.system() == "Linux":
|
||||
run_pip(f"install {xformers_package}", "xformers")
|
||||
run_pip("install xformers", "xformers")
|
||||
|
||||
if not is_installed("pyngrok") and ngrok:
|
||||
run_pip("install pyngrok", "ngrok")
|
||||
|
@ -330,8 +279,6 @@ def tests(test_dir):
|
|||
sys.argv.append("./test/test_files/empty.pt")
|
||||
if "--skip-torch-cuda-test" not in sys.argv:
|
||||
sys.argv.append("--skip-torch-cuda-test")
|
||||
if "--disable-nan-check" not in sys.argv:
|
||||
sys.argv.append("--disable-nan-check")
|
||||
|
||||
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
|
||||
|
||||
|
|
|
@ -11,20 +11,18 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.extras import run_extras
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
import piexif
|
||||
import piexif.helper
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
|
@ -47,46 +45,32 @@ def validate_sampler_name(name):
|
|||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
||||
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
|
||||
reqDict.pop('upscaler_1')
|
||||
reqDict.pop('upscaler_2')
|
||||
return reqDict
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
return Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
if opts.samples_format.lower() == 'png':
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
||||
|
||||
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
||||
parameters = image.info.get('parameters', None)
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
||||
})
|
||||
if opts.samples_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
else:
|
||||
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid image format")
|
||||
# Copy any text-only metadata
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
|
||||
image.save(
|
||||
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
|
@ -142,6 +126,8 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||
|
@ -260,7 +246,7 @@ class Api:
|
|||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
|
||||
with self.queue_lock:
|
||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
|
@ -276,7 +262,7 @@ class Api:
|
|||
reqDict.pop('imageList')
|
||||
|
||||
with self.queue_lock:
|
||||
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
|
@ -376,19 +362,16 @@ class Api:
|
|||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
|
||||
def get_upscalers(self):
|
||||
return [
|
||||
{
|
||||
"name": upscaler.name,
|
||||
"model_name": upscaler.scaler.model_name,
|
||||
"model_path": upscaler.data_path,
|
||||
"model_url": None,
|
||||
"scale": upscaler.scale,
|
||||
}
|
||||
for upscaler in shared.sd_upscalers
|
||||
]
|
||||
upscalers = []
|
||||
|
||||
for upscaler in shared.sd_upscalers:
|
||||
u = upscaler.scaler
|
||||
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
|
||||
|
||||
return upscalers
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
@ -407,6 +390,12 @@ class Api:
|
|||
|
||||
return styleList
|
||||
|
||||
def get_artists_categories(self):
|
||||
return shared.artist_db.cats
|
||||
|
||||
def get_artists(self):
|
||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
|
@ -491,7 +480,7 @@ class Api:
|
|||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.loaded_hypernetworks = []
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
|
@ -502,15 +491,16 @@ class Api:
|
|||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
|
||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
return TrainResponse(info="train embedding error: {error}".format(error=error))
|
||||
return TrainResponse(info = "train embedding error: {error}".format(error = error))
|
||||
|
||||
def get_memory(self):
|
||||
try:
|
||||
|
|
|
@ -220,7 +220,6 @@ class UpscalerItem(BaseModel):
|
|||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
scale: Optional[float] = Field(title="Scale")
|
||||
|
||||
class SDModelItem(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
|
@ -228,7 +227,7 @@ class SDModelItem(BaseModel):
|
|||
hash: Optional[str] = Field(title="Short hash")
|
||||
sha256: Optional[str] = Field(title="sha256 hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: Optional[str] = Field(title="Config file")
|
||||
config: str = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
|
|
25
modules/artists.py
Normal file
25
modules/artists.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import os.path
|
||||
import csv
|
||||
from collections import namedtuple
|
||||
|
||||
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
|
||||
|
||||
|
||||
class ArtistsDatabase:
|
||||
def __init__(self, filename):
|
||||
self.cats = set()
|
||||
self.artists = []
|
||||
|
||||
if not os.path.exists(filename):
|
||||
return
|
||||
|
||||
with open(filename, "r", newline='', encoding="utf8") as file:
|
||||
reader = csv.DictReader(file)
|
||||
|
||||
for row in reader:
|
||||
artist = Artist(row["artist"], float(row["score"]), row["category"])
|
||||
self.artists.append(artist)
|
||||
self.cats.add(artist.category)
|
||||
|
||||
def categories(self):
|
||||
return sorted(self.cats)
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import modules.face_restoration
|
||||
import modules.shared
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import models_path
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
|
|
|
@ -2,8 +2,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules import devices
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
|
@ -198,7 +196,7 @@ class DeepDanbooruModel(nn.Module):
|
|||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
||||
t_360 = self.n_Conv_0(t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
|
|
|
@ -16,10 +16,6 @@ def has_mps() -> bool:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def has_dml():
|
||||
import importlib
|
||||
loader = importlib.find_loader('torch_directml')
|
||||
return loader is not None
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
|
@ -38,25 +34,14 @@ def get_cuda_device_string():
|
|||
return "cuda"
|
||||
|
||||
|
||||
def get_optimal_device_name():
|
||||
if has_dml():
|
||||
return "dml"
|
||||
def get_optimal_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device(get_cuda_device_string())
|
||||
|
||||
if has_mps():
|
||||
return "mps"
|
||||
return torch.device("mps")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return get_cuda_device_string()
|
||||
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if get_optimal_device_name() == "dml":
|
||||
import torch_directml
|
||||
return torch_directml.device()
|
||||
|
||||
return torch.device(get_optimal_device_name())
|
||||
return cpu
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
|
@ -94,16 +79,6 @@ cpu = torch.device("cpu")
|
|||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
dtype_unet = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
def cond_cast_unet(input):
|
||||
return input.to(dtype_unet) if unet_needs_upcast else input
|
||||
|
||||
|
||||
def cond_cast_float(input):
|
||||
return input.float() if unet_needs_upcast else input
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
|
@ -131,42 +106,6 @@ def autocast(disable=False):
|
|||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
def without_autocast(disable=False):
|
||||
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
||||
|
||||
|
||||
class NansException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def test_for_nans(x, where):
|
||||
from modules import shared
|
||||
|
||||
if shared.cmd_opts.disable_nan_check:
|
||||
return
|
||||
|
||||
if not torch.all(torch.isnan(x)).item():
|
||||
return
|
||||
|
||||
if where == "unet":
|
||||
message = "A tensor with all NaNs was produced in Unet."
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
|
||||
|
||||
elif where == "vae":
|
||||
message = "A tensor with all NaNs was produced in VAE."
|
||||
|
||||
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
|
||||
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
|
||||
else:
|
||||
message = "A tensor with all NaNs was produced."
|
||||
|
||||
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||
|
||||
raise NansException(message)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
orig_tensor_to = torch.Tensor.to
|
||||
def tensor_to_fix(self, *args, **kwargs):
|
||||
|
@ -200,10 +139,8 @@ orig_Tensor_cumsum = torch.Tensor.cumsum
|
|||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if output_dtype == torch.int64:
|
||||
if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
|
@ -214,26 +151,8 @@ if has_mps():
|
|||
torch.nn.functional.layer_norm = layer_norm_fix
|
||||
torch.Tensor.numpy = numpy_fix
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||
|
||||
if has_dml():
|
||||
_cumsum = torch.cumsum
|
||||
_repeat_interleave = torch.repeat_interleave
|
||||
_multinomial = torch.multinomial
|
||||
|
||||
_Tensor_new = torch.Tensor.new
|
||||
_Tensor_cumsum = torch.Tensor.cumsum
|
||||
_Tensor_repeat_interleave = torch.Tensor.repeat_interleave
|
||||
_Tensor_multinomial = torch.Tensor.multinomial
|
||||
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
|
||||
torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||
orig_narrow = torch.narrow
|
||||
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
|
||||
|
|
|
@ -19,23 +19,11 @@ def display(e: Exception, task):
|
|||
message = str(e)
|
||||
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||
print_error_explanation("""
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
|
||||
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||
""")
|
||||
|
||||
|
||||
already_displayed = {}
|
||||
|
||||
|
||||
def display_once(e: Exception, task):
|
||||
if task in already_displayed:
|
||||
return
|
||||
|
||||
display(e, task)
|
||||
|
||||
already_displayed[task] = 1
|
||||
|
||||
|
||||
def run(code, task):
|
||||
try:
|
||||
code()
|
||||
|
|
|
@ -7,11 +7,9 @@ import git
|
|||
from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.data_path, "extensions")
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
if not os.path.exists(extensions_dir):
|
||||
os.makedirs(extensions_dir)
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
|
|
@ -1,147 +0,0 @@
|
|||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from modules import errors
|
||||
|
||||
extra_network_registry = {}
|
||||
|
||||
|
||||
def initialize():
|
||||
extra_network_registry.clear()
|
||||
|
||||
|
||||
def register_extra_network(extra_network):
|
||||
extra_network_registry[extra_network.name] = extra_network
|
||||
|
||||
|
||||
class ExtraNetworkParams:
|
||||
def __init__(self, items=None):
|
||||
self.items = items or []
|
||||
|
||||
|
||||
class ExtraNetwork:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def activate(self, p, params_list):
|
||||
"""
|
||||
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
|
||||
Passes arguments related to this extra network in params_list.
|
||||
User passes arguments by specifying this in his prompt:
|
||||
|
||||
<name:arg1:arg2:arg3>
|
||||
|
||||
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
||||
separated by colon.
|
||||
|
||||
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
|
||||
in this case, all effects of this extra networks should be disabled.
|
||||
|
||||
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
||||
|
||||
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
|
||||
|
||||
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
|
||||
|
||||
params_list will be:
|
||||
|
||||
[
|
||||
ExtraNetworkParams(items=["agm", "1.1"]),
|
||||
ExtraNetworkParams(items=["ray"])
|
||||
]
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def deactivate(self, p):
|
||||
"""
|
||||
Called at the end of processing for housekeeping. No need to do anything here.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def activate(p, extra_network_data):
|
||||
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||
activate for all remaining registered networks with an empty argument list"""
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
print(f"Skipping unknown extra network: {extra_network_name}")
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, extra_network_args)
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, [])
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name}")
|
||||
|
||||
|
||||
def deactivate(p, extra_network_data):
|
||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||
deactivate for all remaining registered networks"""
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating extra network {extra_network_name}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
|
||||
|
||||
|
||||
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
|
||||
|
||||
|
||||
def parse_prompt(prompt):
|
||||
res = defaultdict(list)
|
||||
|
||||
def found(m):
|
||||
name = m.group(1)
|
||||
args = m.group(2)
|
||||
|
||||
res[name].append(ExtraNetworkParams(items=args.split(":")))
|
||||
|
||||
return ""
|
||||
|
||||
prompt = re.sub(re_extra_net, found, prompt)
|
||||
|
||||
return prompt, res
|
||||
|
||||
|
||||
def parse_prompts(prompts):
|
||||
res = []
|
||||
extra_data = None
|
||||
|
||||
for prompt in prompts:
|
||||
updated_prompt, parsed_extra_data = parse_prompt(prompt)
|
||||
|
||||
if extra_data is None:
|
||||
extra_data = parsed_extra_data
|
||||
|
||||
res.append(updated_prompt)
|
||||
|
||||
return res, extra_data
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
from modules import extra_networks, shared, extra_networks
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('hypernet')
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_hypernetwork
|
||||
|
||||
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
multipliers = []
|
||||
for params in params_list:
|
||||
assert len(params.items) > 0
|
||||
|
||||
names.append(params.items[0])
|
||||
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||
|
||||
hypernetwork.load_hypernetworks(names, multipliers)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
|
@ -1,16 +1,230 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
||||
from modules.ui_common import plaintext_to_html
|
||||
from typing import Callable, List, OrderedDict, Tuple
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models, sd_samplers
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.codeformer_model
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
|
||||
class LruCache(OrderedDict):
|
||||
@dataclass(frozen=True)
|
||||
class Key:
|
||||
image_hash: int
|
||||
info_hash: int
|
||||
args_hash: int
|
||||
|
||||
@dataclass
|
||||
class Value:
|
||||
image: Image.Image
|
||||
info: str
|
||||
|
||||
def __init__(self, max_size: int = 5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._max_size = max_size
|
||||
|
||||
def get(self, key: LruCache.Key) -> LruCache.Value:
|
||||
ret = super().get(key)
|
||||
if ret is not None:
|
||||
self.move_to_end(key) # Move to end of eviction list
|
||||
return ret
|
||||
|
||||
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
|
||||
self[key] = value
|
||||
while len(self) > self._max_size:
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
cached_images: LruCache = LruCache(max_size=5)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
|
||||
imageArr = []
|
||||
# Also keep track of original file names
|
||||
imageNameArr = []
|
||||
outputs = []
|
||||
|
||||
if extras_mode == 1:
|
||||
#convert file to pillow image
|
||||
for img in image_folder:
|
||||
image = Image.open(img)
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
|
||||
if input_dir == '':
|
||||
return outputs, "Please select an input directory.", ''
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for img in image_list:
|
||||
try:
|
||||
image = Image.open(img)
|
||||
except Exception:
|
||||
continue
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(img)
|
||||
else:
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(None)
|
||||
|
||||
if extras_mode == 2 and output_dir != '':
|
||||
outpath = output_dir
|
||||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
# Extra operation definitions
|
||||
|
||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-gfpgan'
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if gfpgan_visibility < 1.0:
|
||||
res = Image.blend(image, res, gfpgan_visibility)
|
||||
|
||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-codeformer'
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if codeformer_visibility < 1.0:
|
||||
res = Image.blend(image, res, codeformer_visibility)
|
||||
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
shared.state.job = 'extras-upscale'
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
|
||||
res = cropped
|
||||
return res
|
||||
|
||||
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
|
||||
nonlocal upscaling_resize
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
return (image, info)
|
||||
|
||||
@dataclass
|
||||
class UpscaleParams:
|
||||
upscaler_idx: int
|
||||
blend_alpha: float
|
||||
|
||||
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
blended_result: Image.Image = None
|
||||
image_hash: str = hash(np.array(image.getdata()).tobytes())
|
||||
for upscaler in params:
|
||||
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
|
||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
cache_key = LruCache.Key(image_hash=image_hash,
|
||||
info_hash=hash(info),
|
||||
args_hash=hash(upscale_args))
|
||||
cached_entry = cached_images.get(cache_key)
|
||||
if cached_entry is None:
|
||||
res = upscale(image, *upscale_args)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
|
||||
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
|
||||
else:
|
||||
res, info = cached_entry.image, cached_entry.info
|
||||
|
||||
if blended_result is None:
|
||||
blended_result = res
|
||||
else:
|
||||
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
|
||||
return (blended_result, info)
|
||||
|
||||
# Build a list of operations to run
|
||||
facefix_ops: List[Callable] = []
|
||||
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
|
||||
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
|
||||
|
||||
upscale_ops: List[Callable] = []
|
||||
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
|
||||
|
||||
if upscaling_resize != 0:
|
||||
step_params: List[UpscaleParams] = []
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
|
||||
|
||||
upscale_ops.append(partial(run_upscalers_blend, step_params))
|
||||
|
||||
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
|
||||
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
|
||||
shared.state.textinfo = f'Processing image {image_name}'
|
||||
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
image = image.convert("RGB")
|
||||
info = ""
|
||||
# Run each operation on each image
|
||||
for op in extras_ops:
|
||||
image, info = op(image, info)
|
||||
|
||||
if opts.use_original_name_batch and image_name is not None:
|
||||
basename = os.path.splitext(os.path.basename(image_name))[0]
|
||||
else:
|
||||
basename = ''
|
||||
|
||||
if opts.enable_pnginfo: # append info before save
|
||||
image.info = existing_pnginfo
|
||||
image.info["extras"] = info
|
||||
|
||||
if save_output:
|
||||
# Add upscaler name as a suffix.
|
||||
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
|
||||
# Add second upscaler if applicable.
|
||||
if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
|
||||
suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
|
||||
|
||||
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
|
||||
|
||||
if extras_mode != 2 or show_extras_results :
|
||||
outputs.append(image)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return outputs, plaintext_to_html(info), ''
|
||||
|
||||
def clear_cache():
|
||||
cached_images.clear()
|
||||
|
||||
|
||||
def run_pnginfo(image):
|
||||
if image is None:
|
||||
|
@ -37,8 +251,7 @@ def run_pnginfo(image):
|
|||
|
||||
def create_config(ckpt_result, config_source, a, b, c):
|
||||
def config(x):
|
||||
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
||||
return res if res != shared.sd_default_config else None
|
||||
return sd_models.find_checkpoint_config(x) if x else None
|
||||
|
||||
if config_source == 0:
|
||||
cfg = config(a) or config(b) or config(c)
|
||||
|
@ -61,25 +274,10 @@ def create_config(ckpt_result, config_source, a, b, c):
|
|||
shutil.copyfile(cfg, checkpoint_filename)
|
||||
|
||||
|
||||
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
|
||||
def to_half(tensor, enable):
|
||||
if enable and tensor.dtype == torch.float:
|
||||
return tensor.half()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
shared.state.end()
|
||||
return [*[gr.update() for _ in range(4)], message]
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
|
@ -89,156 +287,101 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
def add_difference(theta0, theta1_2_diff, alpha):
|
||||
return theta0 + (alpha * theta1_2_diff)
|
||||
|
||||
def filename_weighted_sum():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
Ma = round(1 - multiplier, 2)
|
||||
Mb = round(multiplier, 2)
|
||||
|
||||
return f"{Ma}({a}) + {Mb}({b})"
|
||||
|
||||
def filename_add_difference():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
c = tertiary_model_info.model_name
|
||||
M = round(multiplier, 2)
|
||||
|
||||
return f"{a} + {M}({b} - {c})"
|
||||
|
||||
def filename_nothing():
|
||||
return primary_model_info.model_name
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||
tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
|
||||
result_is_inpainting_model = False
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
||||
"Add difference": (filename_add_difference, get_difference, add_difference),
|
||||
"No interpolation": (filename_nothing, None, None),
|
||||
"Weighted sum": (None, weighted_sum),
|
||||
"Add difference": (get_difference, add_difference),
|
||||
}
|
||||
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
|
||||
if not primary_model_name:
|
||||
return fail("Failed: Merging requires a primary model.")
|
||||
if theta_func1 and not tertiary_model_info:
|
||||
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
|
||||
shared.state.end()
|
||||
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
|
||||
if theta_func2 and not secondary_model_name:
|
||||
return fail("Failed: Merging requires a secondary model.")
|
||||
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
||||
|
||||
if theta_func1 and not tertiary_model_name:
|
||||
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
||||
|
||||
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
||||
|
||||
result_is_inpainting_model = False
|
||||
result_is_instruct_pix2pix_model = False
|
||||
|
||||
if theta_func2:
|
||||
shared.state.textinfo = f"Loading B"
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
else:
|
||||
theta_1 = None
|
||||
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
if theta_func1:
|
||||
shared.state.textinfo = f"Loading C"
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||
|
||||
shared.state.textinfo = 'Merging B and C'
|
||||
shared.state.sampling_steps = len(theta_1.keys())
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if key in checkpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
if 'model' in key:
|
||||
if key in theta_2:
|
||||
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
del theta_2
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
||||
print("Merging...")
|
||||
shared.state.textinfo = 'Merging A and B'
|
||||
shared.state.sampling_steps = len(theta_0.keys())
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if theta_1 and 'model' in key and key in theta_1:
|
||||
|
||||
if key in checkpoint_dict_skip_on_merge:
|
||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
|
||||
if key in chckpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
shared.state.textinfo = f'Merging layer {key}'
|
||||
# this enables merging an inpainting model (A) with another one (B);
|
||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
||||
if a.shape[1] == 4 and b.shape[1] == 9:
|
||||
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
||||
if a.shape[1] == 4 and b.shape[1] == 8:
|
||||
raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
|
||||
|
||||
if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
|
||||
result_is_instruct_pix2pix_model = True
|
||||
else:
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
else:
|
||||
theta_0[key] = theta_func2(a, b, multiplier)
|
||||
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
|
||||
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
|
||||
if key in chckpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
theta_0[key] = theta_1[key]
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
del theta_1
|
||||
|
||||
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
||||
if bake_in_vae_filename is not None:
|
||||
print(f"Baking in VAE from {bake_in_vae_filename}")
|
||||
shared.state.textinfo = 'Baking in VAE'
|
||||
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
||||
|
||||
for key in vae_dict.keys():
|
||||
theta_0_key = 'first_stage_model.' + key
|
||||
if theta_0_key in theta_0:
|
||||
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
||||
|
||||
del vae_dict
|
||||
|
||||
if save_as_half and not theta_func2:
|
||||
for key in theta_0.keys():
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
if discard_weights:
|
||||
regex = re.compile(discard_weights)
|
||||
for key in list(theta_0):
|
||||
if re.search(regex, key):
|
||||
theta_0.pop(key, None)
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
|
||||
filename = filename_generator() if custom_name == '' else custom_name
|
||||
filename += ".inpainting" if result_is_inpainting_model else ""
|
||||
filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
|
||||
filename += "." + checkpoint_format
|
||||
filename = \
|
||||
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
|
||||
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
|
||||
interp_method.replace(" ", "_") + \
|
||||
'-merged.' + \
|
||||
("inpainting." if result_is_inpainting_model else "") + \
|
||||
checkpoint_format
|
||||
|
||||
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.textinfo = "Saving"
|
||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
|
@ -251,8 +394,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
|
||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||
|
||||
print(f"Checkpoint saved to {output_modelname}.")
|
||||
shared.state.textinfo = "Checkpoint saved"
|
||||
print("Checkpoint saved.")
|
||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
shared.state.end()
|
||||
|
||||
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import base64
|
||||
import html
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
|
@ -7,33 +6,24 @@ import re
|
|||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from modules.paths import data_path
|
||||
from modules.shared import script_path
|
||||
from modules import shared, ui_tempdir, script_callbacks
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
||||
type_of_gr_update = type(gr.update())
|
||||
|
||||
paste_fields = {}
|
||||
registered_param_bindings = []
|
||||
|
||||
|
||||
class ParamBinding:
|
||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
|
||||
self.paste_button = paste_button
|
||||
self.tabname = tabname
|
||||
self.source_text_component = source_text_component
|
||||
self.source_image_component = source_image_component
|
||||
self.source_tabname = source_tabname
|
||||
self.override_settings_component = override_settings_component
|
||||
bind_list = []
|
||||
|
||||
|
||||
def reset():
|
||||
paste_fields.clear()
|
||||
bind_list.clear()
|
||||
|
||||
|
||||
def quote(text):
|
||||
|
@ -47,9 +37,6 @@ def quote(text):
|
|||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if filedata is None:
|
||||
return None
|
||||
|
||||
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
|
||||
filedata = filedata[0]
|
||||
|
||||
|
@ -85,6 +72,28 @@ def add_paste_fields(tabname, init_img, fields):
|
|||
modules.ui.img2img_paste_fields = fields
|
||||
|
||||
|
||||
def integrate_settings_paste_fields(component_dict):
|
||||
from modules import ui
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'inpainting_mask_weight': 'Conditional mask weight',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
'eta_noise_seed_delta': 'ENSD',
|
||||
'initial_noise_multiplier': 'Noise multiplier',
|
||||
}
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
for tabname, info in paste_fields.items():
|
||||
if info["fields"] is not None:
|
||||
info["fields"] += settings_paste_fields
|
||||
|
||||
|
||||
def create_buttons(tabs_list):
|
||||
buttons = {}
|
||||
for tab in tabs_list:
|
||||
|
@ -92,60 +101,9 @@ def create_buttons(tabs_list):
|
|||
return buttons
|
||||
|
||||
|
||||
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
|
||||
def bind_buttons(buttons, send_image, send_generate_info):
|
||||
"""old function for backwards compatibility; do not use this, use register_paste_params_button"""
|
||||
for tabname, button in buttons.items():
|
||||
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
||||
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
||||
|
||||
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
|
||||
|
||||
|
||||
def register_paste_params_button(binding: ParamBinding):
|
||||
registered_param_bindings.append(binding)
|
||||
|
||||
|
||||
def connect_paste_params_buttons():
|
||||
binding: ParamBinding
|
||||
for binding in registered_param_bindings:
|
||||
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
||||
fields = paste_fields[binding.tabname]["fields"]
|
||||
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
||||
if binding.source_image_component and destination_image_component:
|
||||
if isinstance(binding.source_image_component, gr.Gallery):
|
||||
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
||||
jsfunc = "extract_image_from_gallery"
|
||||
else:
|
||||
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
||||
jsfunc = None
|
||||
|
||||
binding.paste_button.click(
|
||||
fn=func,
|
||||
_js=jsfunc,
|
||||
inputs=[binding.source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||
)
|
||||
|
||||
if binding.source_text_component is not None and fields is not None:
|
||||
connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
|
||||
|
||||
if binding.source_tabname is not None and fields is not None:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
binding.paste_button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
)
|
||||
|
||||
binding.paste_button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{binding.tabname}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
)
|
||||
bind_list.append([buttons, send_image, send_generate_info])
|
||||
|
||||
|
||||
def send_image_and_dimensions(x):
|
||||
|
@ -164,6 +122,49 @@ def send_image_and_dimensions(x):
|
|||
return img, w, h
|
||||
|
||||
|
||||
def run_bind():
|
||||
for buttons, source_image_component, send_generate_info in bind_list:
|
||||
for tab in buttons:
|
||||
button = buttons[tab]
|
||||
destination_image_component = paste_fields[tab]["init_img"]
|
||||
fields = paste_fields[tab]["fields"]
|
||||
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
||||
if source_image_component and destination_image_component:
|
||||
if isinstance(source_image_component, gr.Gallery):
|
||||
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
||||
jsfunc = "extract_image_from_gallery"
|
||||
else:
|
||||
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
||||
jsfunc = None
|
||||
|
||||
button.click(
|
||||
fn=func,
|
||||
_js=jsfunc,
|
||||
inputs=[source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||
)
|
||||
|
||||
if send_generate_info and fields is not None:
|
||||
if send_generate_info in paste_fields:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
)
|
||||
else:
|
||||
connect_paste(button, fields, send_generate_info)
|
||||
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{tab}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
|
||||
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
||||
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
||||
|
@ -241,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
done_with_prompt = False
|
||||
|
||||
*lines, lastline = x.strip().split("\n")
|
||||
if len(re_param.findall(lastline)) < 3:
|
||||
if not re_params.match(lastline):
|
||||
lines.append(lastline)
|
||||
lastline = ''
|
||||
|
||||
|
@ -260,7 +261,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
res["Negative prompt"] = negative_prompt
|
||||
|
||||
for k, v in re_param.findall(lastline):
|
||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||
m = re_imagesize.match(v)
|
||||
if m is not None:
|
||||
res[k+"-1"] = m.group(1)
|
||||
|
@ -272,9 +272,13 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
hypernet = res.get("Hypernet", None)
|
||||
if hypernet is not None:
|
||||
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
||||
if "Hypernet strength" not in res:
|
||||
res["Hypernet strength"] = "1"
|
||||
|
||||
if "Hypernet" in res:
|
||||
hypernet_name = res["Hypernet"]
|
||||
hypernet_hash = res.get("Hypernet hash", None)
|
||||
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
||||
|
||||
if "Hires resize-1" not in res:
|
||||
res["Hires resize-1"] = 0
|
||||
|
@ -285,53 +289,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
return res
|
||||
|
||||
|
||||
settings_map = {}
|
||||
|
||||
infotext_to_setting_name_mapping = [
|
||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||
('Model hash', 'sd_model_checkpoint'),
|
||||
('ENSD', 'eta_noise_seed_delta'),
|
||||
('Noise multiplier', 'initial_noise_multiplier'),
|
||||
('Eta', 'eta_ancestral'),
|
||||
('Eta DDIM', 'eta_ddim'),
|
||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
|
||||
]
|
||||
|
||||
|
||||
def create_override_settings_dict(text_pairs):
|
||||
"""creates processing's override_settings parameters from gradio's multiselect
|
||||
|
||||
Example input:
|
||||
['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
|
||||
|
||||
Example output:
|
||||
{'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
|
||||
"""
|
||||
|
||||
res = {}
|
||||
|
||||
params = {}
|
||||
for pair in text_pairs:
|
||||
k, v = pair.split(":", maxsplit=1)
|
||||
|
||||
params[k] = v.strip()
|
||||
|
||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
||||
value = params.get(param_name, None)
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
res[setting_name] = shared.opts.cast_value(setting_name, value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
||||
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(data_path, "params.txt")
|
||||
filename = os.path.join(script_path, "params.txt")
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
prompt = file.read()
|
||||
|
@ -365,35 +326,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
|||
|
||||
return res
|
||||
|
||||
if override_settings_component is not None:
|
||||
def paste_settings(params):
|
||||
vals = {}
|
||||
|
||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
||||
continue
|
||||
|
||||
v = shared.opts.cast_value(setting_name, v)
|
||||
current_value = getattr(shared.opts, setting_name, None)
|
||||
|
||||
if v == current_value:
|
||||
continue
|
||||
|
||||
vals[param_name] = v
|
||||
|
||||
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
|
||||
|
||||
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
|
||||
|
||||
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
|
||||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
_js=jsfunc,
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
|
|
|
@ -6,11 +6,12 @@ import facexlib
|
|||
import gfpgan
|
||||
|
||||
import modules.face_restoration
|
||||
from modules import paths, shared, devices, modelloader
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import models_path
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
user_path = None
|
||||
model_path = os.path.join(paths.models_path, model_dir)
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
|
|
|
@ -4,11 +4,8 @@ import os.path
|
|||
|
||||
import filelock
|
||||
|
||||
from modules import shared
|
||||
from modules.paths import data_path
|
||||
|
||||
|
||||
cache_filename = os.path.join(data_path, "cache.json")
|
||||
cache_filename = "cache.json"
|
||||
cache_data = None
|
||||
|
||||
|
||||
|
@ -69,9 +66,6 @@ def sha256(filename, title):
|
|||
if sha256_value is not None:
|
||||
return sha256_value
|
||||
|
||||
if shared.cmd_opts.no_hashing:
|
||||
return None
|
||||
|
||||
print(f"Calculating sha256 for {filename}: ", end='')
|
||||
sha256_value = calculate_sha256(filename)
|
||||
print(f"{sha256_value}")
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
|||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes
|
||||
from modules.textual_inversion import textual_inversion, logging
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
|
@ -25,6 +25,7 @@ from statistics import stdev, mean
|
|||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
|
@ -40,8 +41,6 @@ class HypernetworkModule(torch.nn.Module):
|
|||
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
||||
super().__init__()
|
||||
|
||||
self.multiplier = 1.0
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
@ -116,7 +115,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||
state_dict[to] = x
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
||||
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
|
||||
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
|
@ -126,6 +125,9 @@ class HypernetworkModule(torch.nn.Module):
|
|||
return layer_structure
|
||||
|
||||
|
||||
def apply_strength(value=None):
|
||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||
|
||||
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
||||
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
||||
if layer_structure is None:
|
||||
|
@ -190,20 +192,6 @@ class Hypernetwork:
|
|||
for param in layer.parameters():
|
||||
param.requires_grad = mode
|
||||
|
||||
def to(self, device):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.multiplier = multiplier
|
||||
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
|
@ -281,13 +269,11 @@ class Hypernetwork:
|
|||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
else:
|
||||
self.optimizer_name = "AdamW"
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
|
@ -307,7 +293,7 @@ class Hypernetwork:
|
|||
def shorthash(self):
|
||||
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||
|
||||
return sha256[0:10] if sha256 else None
|
||||
return sha256[0:10]
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
|
@ -320,43 +306,23 @@ def list_hypernetworks(path):
|
|||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(name):
|
||||
path = shared.hypernetworks.get(name, None)
|
||||
def load_hypernetwork(filename):
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
# Prevent any file named "None.pt" from being loaded.
|
||||
if path is not None and filename != "None":
|
||||
print(f"Loading hypernetwork {filename}")
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
|
||||
if path is None:
|
||||
return None
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
if shared.loaded_hypernetwork is not None:
|
||||
print("Unloading hypernetwork")
|
||||
|
||||
hypernetwork = Hypernetwork()
|
||||
|
||||
try:
|
||||
hypernetwork.load(path)
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
return hypernetwork
|
||||
|
||||
|
||||
def load_hypernetworks(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
|
||||
for hypernetwork in shared.loaded_hypernetworks:
|
||||
if hypernetwork.name in names:
|
||||
already_loaded[hypernetwork.name] = hypernetwork
|
||||
|
||||
shared.loaded_hypernetworks.clear()
|
||||
|
||||
for i, name in enumerate(names):
|
||||
hypernetwork = already_loaded.get(name, None)
|
||||
if hypernetwork is None:
|
||||
hypernetwork = load_hypernetwork(name)
|
||||
|
||||
if hypernetwork is None:
|
||||
continue
|
||||
|
||||
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
|
||||
shared.loaded_hypernetworks.append(hypernetwork)
|
||||
shared.loaded_hypernetwork = None
|
||||
|
||||
|
||||
def find_closest_hypernetwork_name(search: str):
|
||||
|
@ -370,27 +336,18 @@ def find_closest_hypernetwork_name(search: str):
|
|||
return applicable[0]
|
||||
|
||||
|
||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is None:
|
||||
return context_k, context_v
|
||||
return context, context
|
||||
|
||||
if layer is not None:
|
||||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = hypernetwork_layers[0](context_k)
|
||||
context_v = hypernetwork_layers[1](context_v)
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def apply_hypernetworks(hypernetworks, context, layer=None):
|
||||
context_k = context
|
||||
context_v = context
|
||||
for hypernetwork in hypernetworks:
|
||||
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
|
||||
|
||||
context_k = hypernetwork_layers[0](context)
|
||||
context_v = hypernetwork_layers[1](context)
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
|
@ -400,7 +357,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
|||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
|
||||
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
|
||||
|
@ -507,9 +464,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
template_file = template_file.path
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
hypernetwork = Hypernetwork()
|
||||
hypernetwork.load(path)
|
||||
shared.loaded_hypernetworks = [hypernetwork]
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
|
@ -533,6 +489,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
initial_step = hypernetwork.step or 0
|
||||
|
@ -604,7 +561,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
_loss_step = 0 #internal
|
||||
# size = len(ds.indexes)
|
||||
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
|
||||
# losses = torch.zeros((size,))
|
||||
# previous_mean_losses = [0]
|
||||
# previous_mean_loss = 0
|
||||
|
@ -618,8 +574,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
sd_hijack_checkpoint.add()
|
||||
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
@ -656,7 +610,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
loss_logging.append(_loss_step)
|
||||
|
||||
if clip_grad:
|
||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||
|
||||
|
@ -690,7 +644,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
if shared.opts.training_enable_tensorboard:
|
||||
epoch_num = hypernetwork.step // len(ds)
|
||||
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||
mean_loss = sum(loss_logging) / len(loss_logging)
|
||||
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
|
||||
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||
|
@ -715,8 +669,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
p.disable_extra_networks = True
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
|
@ -736,6 +688,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
|
@ -746,10 +701,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
hypernetwork.train()
|
||||
if image is not None:
|
||||
shared.state.assign_current_image(image)
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
textual_inversion.tensorboard_add_image(tensorboard_writer,
|
||||
f"Validation at epoch {epoch_num}", image,
|
||||
hypernetwork.step)
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
|
@ -771,9 +723,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.close()
|
||||
hypernetwork.eval()
|
||||
#report_statistics(loss_dict)
|
||||
sd_hijack_checkpoint.remove()
|
||||
|
||||
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
|
|
|
@ -9,7 +9,6 @@ from modules import devices, sd_hijack, shared
|
|||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||
|
||||
|
@ -17,7 +16,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
|
||||
|
||||
def train_hypernetwork(*args):
|
||||
shared.loaded_hypernetworks = []
|
||||
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||
|
||||
|
@ -34,6 +34,7 @@ Hypernetwork saved to {html.escape(filename)}
|
|||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
sd_hijack.apply_optimizations()
|
||||
|
|
|
@ -16,7 +16,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
|||
from fonts.ttf import Roboto
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from modules import sd_samplers, shared, script_callbacks
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
@ -37,8 +36,6 @@ def image_grid(imgs, batch_size=1, rows=None):
|
|||
else:
|
||||
rows = math.sqrt(len(imgs))
|
||||
rows = round(rows)
|
||||
if rows > len(imgs):
|
||||
rows = len(imgs)
|
||||
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
|
||||
|
@ -131,7 +128,7 @@ class GridAnnotation:
|
|||
self.size = None
|
||||
|
||||
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
def wrap(drawing, text, font, line_length):
|
||||
lines = ['']
|
||||
for word in text.split():
|
||||
|
@ -195,35 +192,32 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
|||
line.allowed_width = allowed_width
|
||||
|
||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||
ver_texts]
|
||||
|
||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||
pad_top = max(hor_text_heights) + line_spacing * 2
|
||||
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
|
||||
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
||||
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
||||
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
|
||||
result.paste(im, (pad_left, pad_top))
|
||||
|
||||
d = ImageDraw.Draw(result)
|
||||
|
||||
for col in range(cols):
|
||||
x = pad_left + (width + margin) * col + width / 2
|
||||
x = pad_left + width * col + width / 2
|
||||
y = pad_top / 2 - hor_text_heights[col] / 2
|
||||
|
||||
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
||||
|
||||
for row in range(rows):
|
||||
x = pad_left / 2
|
||||
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
||||
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
|
||||
|
||||
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
||||
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
prompts = all_prompts[1:]
|
||||
boundary = math.ceil(len(prompts) / 2)
|
||||
|
||||
|
@ -233,7 +227,7 @@ def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
|||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
||||
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||
|
||||
|
||||
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
|
@ -344,7 +338,6 @@ class FilenameGenerator:
|
|||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
|
@ -612,9 +605,8 @@ def read_info_from_image(image):
|
|||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
|
||||
if exif_comment:
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
|
||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration']:
|
||||
|
|
|
@ -7,7 +7,6 @@ import numpy as np
|
|||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||
|
||||
from modules import devices, sd_samplers
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
|
@ -17,18 +16,11 @@ import modules.images as images
|
|||
import modules.scripts
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
||||
def process_batch(p, input_dir, output_dir, args):
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = shared.listfiles(input_dir)
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
inpaint_masks = shared.listfiles(inpaint_mask_dir)
|
||||
is_inpaint_batch = len(inpaint_masks) > 0
|
||||
if is_inpaint_batch:
|
||||
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
save_normally = output_dir == ''
|
||||
|
@ -51,15 +43,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
if is_inpaint_batch:
|
||||
# try to find corresponding mask for an image using simple filename matching
|
||||
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
||||
# if not found use first one ("same mask for all images" use-case)
|
||||
if not mask_image_path in inpaint_masks:
|
||||
mask_image_path = inpaint_masks[0]
|
||||
mask_image = Image.open(mask_image_path)
|
||||
p.image_mask = mask_image
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if proc is None:
|
||||
proc = process_images(p)
|
||||
|
@ -76,9 +59,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
is_batch = mode == 5
|
||||
|
||||
if mode == 0: # img2img
|
||||
|
@ -120,7 +101,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
styles=prompt_styles,
|
||||
styles=[prompt_style, prompt_style2],
|
||||
seed=seed,
|
||||
subseed=subseed,
|
||||
subseed_strength=subseed_strength,
|
||||
|
@ -142,11 +123,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||
inpainting_fill=inpainting_fill,
|
||||
resize_mode=resize_mode,
|
||||
denoising_strength=denoising_strength,
|
||||
image_cfg_scale=image_cfg_scale,
|
||||
inpaint_full_res=inpaint_full_res,
|
||||
inpaint_full_res_padding=inpaint_full_res_padding,
|
||||
inpainting_mask_invert=inpainting_mask_invert,
|
||||
override_settings=override_settings,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
|
@ -160,7 +139,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
|
||||
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
else:
|
||||
|
|
|
@ -2,17 +2,15 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.hub
|
||||
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
from modules import devices, paths, lowvram, modelloader
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
@ -21,76 +19,30 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
|||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
|
||||
def category_types():
|
||||
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
||||
|
||||
|
||||
def download_default_clip_interrogate_categories(content_dir):
|
||||
print("Downloading CLIP categories...")
|
||||
|
||||
tmpdir = content_dir + "_tmp"
|
||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||
|
||||
try:
|
||||
os.makedirs(tmpdir)
|
||||
for category_type in category_types:
|
||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||
os.rename(tmpdir, content_dir)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "downloading default CLIP interrogate categories")
|
||||
finally:
|
||||
if os.path.exists(tmpdir):
|
||||
os.remove(tmpdir)
|
||||
|
||||
|
||||
class InterrogateModels:
|
||||
blip_model = None
|
||||
clip_model = None
|
||||
clip_preprocess = None
|
||||
categories = None
|
||||
dtype = None
|
||||
running_on_cpu = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.loaded_categories = None
|
||||
self.skip_categories = []
|
||||
self.content_dir = content_dir
|
||||
self.categories = []
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
def categories(self):
|
||||
if not os.path.exists(self.content_dir):
|
||||
download_default_clip_interrogate_categories(self.content_dir)
|
||||
|
||||
if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
|
||||
return self.loaded_categories
|
||||
|
||||
self.loaded_categories = []
|
||||
|
||||
if os.path.exists(self.content_dir):
|
||||
self.skip_categories = shared.opts.interrogate_clip_skip_categories
|
||||
category_types = []
|
||||
for filename in Path(self.content_dir).glob('*.txt'):
|
||||
category_types.append(filename.stem)
|
||||
if filename.stem in self.skip_categories:
|
||||
continue
|
||||
m = re_topn.search(filename.stem)
|
||||
if os.path.exists(content_dir):
|
||||
for filename in os.listdir(content_dir):
|
||||
m = re_topn.search(filename)
|
||||
topn = 1 if m is None else int(m.group(1))
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
|
||||
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
|
||||
|
||||
return self.loaded_categories
|
||||
|
||||
def create_fake_fairscale(self):
|
||||
class FakeFairscale:
|
||||
def checkpoint_wrapper(self):
|
||||
pass
|
||||
|
||||
sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
|
||||
self.categories.append(Category(name=filename, topn=topn, items=lines))
|
||||
|
||||
def load_blip_model(self):
|
||||
self.create_fake_fairscale()
|
||||
import models.blip
|
||||
|
||||
files = modelloader.load_models(
|
||||
|
@ -154,8 +106,6 @@ class InterrogateModels:
|
|||
def rank(self, image_features, text_array, top_count=1):
|
||||
import clip
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if shared.opts.interrogate_clip_dict_limit != 0:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
|
@ -189,6 +139,7 @@ class InterrogateModels:
|
|||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
try:
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
|
@ -208,7 +159,12 @@ class InterrogateModels:
|
|||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
for name, topn, items in self.categories():
|
||||
if shared.opts.interrogate_use_builtin_artists:
|
||||
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
||||
|
||||
res += ", " + artist[0]
|
||||
|
||||
for name, topn, items in self.categories:
|
||||
matches = self.rank(image_features, items, top_count=topn)
|
||||
for match, score in matches:
|
||||
if shared.opts.interrogate_return_ranks:
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
import torch
|
||||
from modules import paths
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def check_for_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
has_mps = check_for_mps()
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if output_dtype == torch.int64:
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
if has_mps:
|
||||
# MPS fix for randn in torchsde
|
||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
|
|
@ -45,9 +45,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||
full_path = file
|
||||
if os.path.isdir(full_path):
|
||||
continue
|
||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||
print(f"Skipping broken symlink: {full_path}")
|
||||
continue
|
||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -4,15 +4,7 @@ import sys
|
|||
import modules.safe
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser.parse_known_args()[0]
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
models_path = os.path.join(data_path, "models")
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
models_path = os.path.join(script_path, "models")
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffusion in following places
|
||||
|
@ -46,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
|
|||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
||||
|
||||
class Prioritize:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.path = None
|
||||
|
||||
def __enter__(self):
|
||||
self.path = sys.path.copy()
|
||||
sys.path = [paths[self.name]] + sys.path
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
sys.path = self.path
|
||||
self.path = None
|
||||
|
|
|
@ -1,103 +0,0 @@
|
|||
import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
|
||||
image_data = []
|
||||
image_names = []
|
||||
outputs = []
|
||||
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
image = Image.open(img)
|
||||
image_data.append(image)
|
||||
image_names.append(os.path.splitext(img.orig_name)[0])
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
assert input_dir, 'input directory not selected'
|
||||
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for filename in image_list:
|
||||
try:
|
||||
image = Image.open(filename)
|
||||
except Exception:
|
||||
continue
|
||||
image_data.append(image)
|
||||
image_names.append(filename)
|
||||
else:
|
||||
assert image, 'image not selected'
|
||||
|
||||
image_data.append(image)
|
||||
image_names.append(None)
|
||||
|
||||
if extras_mode == 2 and output_dir != '':
|
||||
outpath = output_dir
|
||||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
infotext = ''
|
||||
|
||||
for image, name in zip(image_data, image_names):
|
||||
shared.state.textinfo = name
|
||||
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
||||
|
||||
scripts.scripts_postproc.run(pp, args)
|
||||
|
||||
if opts.use_original_name_batch and name is not None:
|
||||
basename = os.path.splitext(os.path.basename(name))[0]
|
||||
else:
|
||||
basename = ''
|
||||
|
||||
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
pp.image.info = existing_pnginfo
|
||||
pp.image.info["postprocessing"] = infotext
|
||||
|
||||
if save_output:
|
||||
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
||||
|
||||
if extras_mode != 2 or show_extras_results:
|
||||
outputs.append(pp.image)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||
"""old handler for API"""
|
||||
|
||||
args = scripts.scripts_postproc.create_args_for_run({
|
||||
"Upscale": {
|
||||
"upscale_mode": resize_mode,
|
||||
"upscale_by": upscaling_resize,
|
||||
"upscale_to_width": upscaling_resize_w,
|
||||
"upscale_to_height": upscaling_resize_h,
|
||||
"upscale_crop": upscaling_crop,
|
||||
"upscaler_1_name": extras_upscaler_1,
|
||||
"upscaler_2_name": extras_upscaler_2,
|
||||
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
||||
},
|
||||
"GFPGAN": {
|
||||
"gfpgan_visibility": gfpgan_visibility,
|
||||
},
|
||||
"CodeFormer": {
|
||||
"codeformer_visibility": codeformer_visibility,
|
||||
"codeformer_weight": codeformer_weight,
|
||||
},
|
||||
})
|
||||
|
||||
return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
|
|
@ -13,11 +13,10 @@ from skimage import exposure
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.paths as paths
|
||||
import modules.face_restoration
|
||||
import modules.images as images
|
||||
import modules.styles
|
||||
|
@ -95,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
|||
return image_conditioning
|
||||
|
||||
|
||||
class StableDiffusionProcessing:
|
||||
class StableDiffusionProcessing():
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
"""
|
||||
|
@ -103,6 +102,7 @@ class StableDiffusionProcessing:
|
|||
if sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
self.sd_model = sd_model
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
self.prompt: str = prompt
|
||||
|
@ -141,7 +141,6 @@ class StableDiffusionProcessing:
|
|||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||
self.is_using_inpainting_conditioning = False
|
||||
self.disable_extra_networks = False
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
|
@ -157,10 +156,6 @@ class StableDiffusionProcessing:
|
|||
self.all_subseeds = None
|
||||
self.iteration = 0
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
return shared.sd_model
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
|
||||
|
@ -185,12 +180,7 @@ class StableDiffusionProcessing:
|
|||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||
return conditioning
|
||||
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
|
||||
|
||||
return conditioning_image
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
|
@ -209,7 +199,7 @@ class StableDiffusionProcessing:
|
|||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
||||
conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
|
||||
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
|
||||
conditioning_image = torch.lerp(
|
||||
source_image,
|
||||
source_image * (1.0 - conditioning_mask),
|
||||
|
@ -228,16 +218,11 @@ class StableDiffusionProcessing:
|
|||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
source_image = devices.cond_cast_float(source_image)
|
||||
|
||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||
return self.depth2img_image_conditioning(source_image)
|
||||
|
||||
if self.sd_model.cond_stage_key == "edit":
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
|
@ -251,6 +236,7 @@ class StableDiffusionProcessing:
|
|||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
self.sd_model = None
|
||||
self.sampler = None
|
||||
|
||||
|
||||
|
@ -268,7 +254,6 @@ class Processed:
|
|||
self.height = p.height
|
||||
self.sampler_name = p.sampler_name
|
||||
self.cfg_scale = p.cfg_scale
|
||||
self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.steps = p.steps
|
||||
self.batch_size = p.batch_size
|
||||
self.restore_faces = p.restore_faces
|
||||
|
@ -446,17 +431,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Steps": p.steps,
|
||||
"Sampler": p.sampler_name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
}
|
||||
|
@ -476,12 +466,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
try:
|
||||
for k, v in p.override_settings.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork':
|
||||
shared.reload_hypernetworks() # make onchange call for changing hypernet
|
||||
|
||||
if k == 'sd_model_checkpoint':
|
||||
sd_models.reload_model_weights()
|
||||
sd_models.reload_model_weights() # make onchange call for changing SD model
|
||||
p.sd_model = shared.sd_model
|
||||
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
sd_vae.reload_vae_weights() # make onchange call for changing VAE
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
|
@ -490,11 +483,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_model_checkpoint':
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
||||
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
|
||||
if k == 'sd_vae': sd_vae.reload_vae_weights()
|
||||
|
||||
return res
|
||||
|
||||
|
@ -543,11 +534,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
_, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process(p)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
||||
|
@ -578,17 +571,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
|
||||
# for OSX, loading the model during sampling changes the generated picture, so it is loaded here
|
||||
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
||||
sd_vae_approx.model()
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
extra_networks.activate(p, extra_network_data)
|
||||
|
||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if state.job_count == -1:
|
||||
state.job_count = p.n_iter
|
||||
|
||||
|
@ -609,8 +591,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
prompts, _ = extra_networks.parse_prompts(prompts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||
|
||||
|
@ -624,13 +604,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
with devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||
|
||||
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
||||
for x in x_samples_ddim:
|
||||
devices.test_for_nans(x, "vae")
|
||||
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
|
@ -659,11 +636,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
image = Image.fromarray(x_sample)
|
||||
|
||||
if p.scripts is not None:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
image = pp.image
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
|
@ -705,9 +677,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
extra_networks.deactivate(p, extra_network_data)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
@ -884,8 +853,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
|
||||
shared.state.nextjob()
|
||||
|
||||
img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
|
||||
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
||||
|
||||
|
@ -903,13 +871,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.init_images = init_images
|
||||
self.resize_mode: int = resize_mode
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||
self.init_latent = None
|
||||
self.image_mask = mask
|
||||
self.latent_mask = None
|
||||
|
|
|
@ -67,13 +67,10 @@ def progressapi(req: ProgressRequest):
|
|||
|
||||
progress = 0
|
||||
|
||||
job_count, job_no = shared.state.job_count, shared.state.job_no
|
||||
sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
|
||||
|
||||
if job_count > 0:
|
||||
progress += job_no / job_count
|
||||
if sampling_steps > 0 and job_count > 0:
|
||||
progress += 1 / job_count * sampling_step / sampling_steps
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
|
|
|
@ -274,7 +274,6 @@ re_attention = re.compile(r"""
|
|||
:
|
||||
""", re.X)
|
||||
|
||||
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
|
@ -340,11 +339,7 @@ def parse_prompt_attention(text):
|
|||
elif text == ']' and len(square_brackets) > 0:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
parts = re.split(re_break, text)
|
||||
for i, part in enumerate(parts):
|
||||
if i > 0:
|
||||
res.append(["BREAK", -1])
|
||||
res.append([part, 1.0])
|
||||
res.append([text, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
|
|
|
@ -38,15 +38,15 @@ class UpscalerRealESRGAN(Upscaler):
|
|||
return img
|
||||
|
||||
info = self.load_model(path)
|
||||
if not os.path.exists(info.local_data_path):
|
||||
if not os.path.exists(info.data_path):
|
||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.scale,
|
||||
model_path=info.local_data_path,
|
||||
model_path=info.data_path,
|
||||
model=info.model(),
|
||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||
half=not cmd_opts.no_half,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
@ -58,13 +58,17 @@ class UpscalerRealESRGAN(Upscaler):
|
|||
|
||||
def load_model(self, path):
|
||||
try:
|
||||
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
|
||||
info = None
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
info = scaler
|
||||
|
||||
if info is None:
|
||||
print(f"Unable to find model info: {path}")
|
||||
return None
|
||||
|
||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
info.data_path = model_file
|
||||
return info
|
||||
except Exception as e:
|
||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||
|
|
|
@ -73,7 +73,6 @@ callback_map = dict(
|
|||
callbacks_image_grid=[],
|
||||
callbacks_infotext_pasted=[],
|
||||
callbacks_script_unloaded=[],
|
||||
callbacks_before_ui=[],
|
||||
)
|
||||
|
||||
|
||||
|
@ -190,14 +189,6 @@ def script_unloaded_callback():
|
|||
report_exception(c, 'script_unloaded')
|
||||
|
||||
|
||||
def before_ui_callback():
|
||||
for c in reversed(callback_map['callbacks_before_ui']):
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'before_ui')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
|
@ -322,9 +313,3 @@ def on_script_unloaded(callback):
|
|||
the script did should be reverted here"""
|
||||
|
||||
add_callback(callback_map['callbacks_script_unloaded'], callback)
|
||||
|
||||
|
||||
def on_before_ui(callback):
|
||||
"""register a function to be called before the UI is created."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_ui'], callback)
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(os.path.basename(path))
|
||||
exec(compiled, module.__dict__)
|
||||
|
||||
return module
|
||||
|
||||
|
|
|
@ -6,16 +6,12 @@ from collections import namedtuple
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
|
||||
class Script:
|
||||
filename = None
|
||||
args_from = None
|
||||
|
@ -69,7 +65,7 @@ class Script:
|
|||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
raise NotImplementedError()
|
||||
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
|
@ -104,13 +100,6 @@ class Script:
|
|||
|
||||
pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
|
@ -161,10 +150,8 @@ def basedir():
|
|||
return current_basedir
|
||||
|
||||
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||
|
||||
scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
|
@ -203,31 +190,23 @@ def list_files_with_name(filename):
|
|||
def load_scripts():
|
||||
global current_basedir
|
||||
scripts_data.clear()
|
||||
postprocessing_scripts_data.clear()
|
||||
script_callbacks.clear_callbacks()
|
||||
|
||||
scripts_list = list_scripts("scripts", ".py")
|
||||
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module):
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
|
||||
if issubclass(script_class, Script):
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
for scriptfile in sorted(scripts_list):
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
current_basedir = scriptfile.basedir
|
||||
|
||||
script_module = script_loading.load_module(scriptfile.path)
|
||||
register_scripts_from_module(script_module)
|
||||
module = script_loading.load_module(scriptfile.path)
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
except Exception:
|
||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||
|
@ -258,15 +237,11 @@ class ScriptRunner:
|
|||
self.infotext_fields = []
|
||||
|
||||
def initialize_scripts(self, is_img2img):
|
||||
from modules import scripts_auto_postprocessing
|
||||
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
|
||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
||||
for script_class, path, basedir, script_module in scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
|
@ -345,23 +320,9 @@ class ScriptRunner:
|
|||
outputs=[script.group for script in self.selectable_scripts]
|
||||
)
|
||||
|
||||
self.script_load_ctr = 0
|
||||
def onload_script_visibility(params):
|
||||
title = params.get('Script', None)
|
||||
if title:
|
||||
title_index = self.titles.index(title)
|
||||
visibility = title_index == self.script_load_ctr
|
||||
self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
|
||||
return gr.update(visible=visibility)
|
||||
else:
|
||||
return gr.update(visible=False)
|
||||
|
||||
self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
|
||||
self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
|
||||
|
||||
return inputs
|
||||
|
||||
def run(self, p, *args):
|
||||
def run(self, p: StableDiffusionProcessing, *args):
|
||||
script_index = args[0]
|
||||
|
||||
if script_index == 0:
|
||||
|
@ -415,15 +376,6 @@ class ScriptRunner:
|
|||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_image(p, pp, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
|
@ -461,7 +413,6 @@ class ScriptRunner:
|
|||
|
||||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
scripts_current: ScriptRunner = None
|
||||
|
||||
|
||||
|
@ -472,13 +423,12 @@ def reload_script_body_only():
|
|||
|
||||
|
||||
def reload_scripts():
|
||||
global scripts_txt2img, scripts_img2img, scripts_postproc
|
||||
global scripts_txt2img, scripts_img2img
|
||||
|
||||
load_scripts()
|
||||
|
||||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
from modules import scripts, scripts_postprocessing, shared
|
||||
|
||||
|
||||
class ScriptPostprocessingForMainUI(scripts.Script):
|
||||
def __init__(self, script_postproc):
|
||||
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
|
||||
self.postprocessing_controls = None
|
||||
|
||||
def title(self):
|
||||
return self.script.name
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.postprocessing_controls = self.script.ui()
|
||||
return self.postprocessing_controls.values()
|
||||
|
||||
def postprocess_image(self, p, script_pp, *args):
|
||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||
pp.info = {}
|
||||
self.script.process(pp, **args_dict)
|
||||
p.extra_generation_params.update(pp.info)
|
||||
script_pp.image = pp.image
|
||||
|
||||
|
||||
def create_auto_preprocessing_script_data():
|
||||
from modules import scripts
|
||||
|
||||
res = []
|
||||
|
||||
for name in shared.opts.postprocessing_enable_in_main_ui:
|
||||
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
|
||||
if script is None:
|
||||
continue
|
||||
|
||||
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
|
||||
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
|
||||
|
||||
return res
|
|
@ -1,152 +0,0 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors, shared
|
||||
|
||||
|
||||
class PostprocessedImage:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
self.info = {}
|
||||
|
||||
|
||||
class ScriptPostprocessing:
|
||||
filename = None
|
||||
controls = None
|
||||
args_from = None
|
||||
args_to = None
|
||||
|
||||
order = 1000
|
||||
"""scripts will be ordred by this value in postprocessing UI"""
|
||||
|
||||
name = None
|
||||
"""this function should return the title of the script."""
|
||||
|
||||
group = None
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
|
||||
def ui(self):
|
||||
"""
|
||||
This function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be a dictionary that maps parameter names to components used in processing.
|
||||
Values of those components will be passed to process() function.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def process(self, pp: PostprocessedImage, **args):
|
||||
"""
|
||||
This function is called to postprocess the image.
|
||||
args contains a dictionary with all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def image_changed(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
errors.display(e, f"calling {filename}/{funcname}")
|
||||
|
||||
return default
|
||||
|
||||
|
||||
class ScriptPostprocessingRunner:
|
||||
def __init__(self):
|
||||
self.scripts = None
|
||||
self.ui_created = False
|
||||
|
||||
def initialize_scripts(self, scripts_data):
|
||||
self.scripts = []
|
||||
|
||||
for script_class, path, basedir, script_module in scripts_data:
|
||||
script: ScriptPostprocessing = script_class()
|
||||
script.filename = path
|
||||
|
||||
if script.name == "Simple Upscale":
|
||||
continue
|
||||
|
||||
self.scripts.append(script)
|
||||
|
||||
def create_script_ui(self, script, inputs):
|
||||
script.args_from = len(inputs)
|
||||
script.args_to = len(inputs)
|
||||
|
||||
script.controls = wrap_call(script.ui, script.filename, "ui")
|
||||
|
||||
for control in script.controls.values():
|
||||
control.custom_script_source = os.path.basename(script.filename)
|
||||
|
||||
inputs += list(script.controls.values())
|
||||
script.args_to = len(inputs)
|
||||
|
||||
def scripts_in_preferred_order(self):
|
||||
if self.scripts is None:
|
||||
import modules.scripts
|
||||
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
||||
|
||||
scripts_order = shared.opts.postprocessing_operation_order
|
||||
|
||||
def script_score(name):
|
||||
for i, possible_match in enumerate(scripts_order):
|
||||
if possible_match == name:
|
||||
return i
|
||||
|
||||
return len(self.scripts)
|
||||
|
||||
script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
|
||||
|
||||
return sorted(self.scripts, key=lambda x: script_scores[x.name])
|
||||
|
||||
def setup_ui(self):
|
||||
inputs = []
|
||||
|
||||
for script in self.scripts_in_preferred_order():
|
||||
with gr.Box() as group:
|
||||
self.create_script_ui(script, inputs)
|
||||
|
||||
script.group = group
|
||||
|
||||
self.ui_created = True
|
||||
return inputs
|
||||
|
||||
def run(self, pp: PostprocessedImage, args):
|
||||
for script in self.scripts_in_preferred_order():
|
||||
shared.state.job = script.name
|
||||
|
||||
script_args = args[script.args_from:script.args_to]
|
||||
|
||||
process_args = {}
|
||||
for (name, component), value in zip(script.controls.items(), script_args):
|
||||
process_args[name] = value
|
||||
|
||||
script.process(pp, **process_args)
|
||||
|
||||
def create_args_for_run(self, scripts_args):
|
||||
if not self.ui_created:
|
||||
with gr.Blocks(analytics_enabled=False):
|
||||
self.setup_ui()
|
||||
|
||||
scripts = self.scripts_in_preferred_order()
|
||||
args = [None] * max([x.args_to for x in scripts])
|
||||
|
||||
for script in scripts:
|
||||
script_args_dict = scripts_args.get(script.name, None)
|
||||
if script_args_dict is not None:
|
||||
|
||||
for i, name in enumerate(script.controls):
|
||||
args[script.args_from + i] = script_args_dict.get(name, None)
|
||||
|
||||
return args
|
||||
|
||||
def image_changed(self):
|
||||
for script in self.scripts_in_preferred_order():
|
||||
script.image_changed()
|
||||
|
|
@ -20,9 +20,8 @@ class DisableInitialization:
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self, disable_clip=True):
|
||||
def __init__(self):
|
||||
self.replaced = []
|
||||
self.disable_clip = disable_clip
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
|
@ -42,9 +41,7 @@ class DisableInitialization:
|
|||
return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
||||
|
||||
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||
res.name_or_path = pretrained_model_name_or_path
|
||||
return res
|
||||
return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||
|
||||
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
||||
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
||||
|
@ -76,14 +73,12 @@ class DisableInitialization:
|
|||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
|
||||
if self.disable_clip:
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for obj, field, original in self.replaced:
|
||||
|
|
|
@ -70,10 +70,9 @@ def undo_optimizations():
|
|||
|
||||
|
||||
def fix_checkpoint():
|
||||
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||
checkpoints to be added when not training (there's a warning)"""
|
||||
|
||||
pass
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
||||
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
|
@ -107,6 +106,8 @@ class StableDiffusionModelHijack:
|
|||
self.optimization_method = apply_optimizations()
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
fix_checkpoint()
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
|
@ -131,8 +132,6 @@ class StableDiffusionModelHijack:
|
|||
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
undo_optimizations()
|
||||
|
||||
self.apply_circular(False)
|
||||
self.layers = None
|
||||
self.clip = None
|
||||
|
@ -173,7 +172,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = devices.cond_cast_unet(embedding.vec)
|
||||
emb = embedding.vec
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
|
|
|
@ -1,46 +1,10 @@
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
|
||||
|
||||
def BasicTransformerBlock_forward(self, x, context=None):
|
||||
return checkpoint(self._forward, x, context)
|
||||
|
||||
|
||||
def AttentionBlock_forward(self, x):
|
||||
return checkpoint(self._forward, x)
|
||||
|
||||
|
||||
def ResBlock_forward(self, x, emb):
|
||||
return checkpoint(self._forward, x, emb)
|
||||
|
||||
|
||||
stored = []
|
||||
|
||||
|
||||
def add():
|
||||
if len(stored) != 0:
|
||||
return
|
||||
|
||||
stored.extend([
|
||||
ldm.modules.attention.BasicTransformerBlock.forward,
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
|
||||
])
|
||||
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
||||
|
||||
|
||||
def remove():
|
||||
if len(stored) == 0:
|
||||
return
|
||||
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
|
||||
|
||||
stored.clear()
|
||||
|
||||
return checkpoint(self._forward, x, emb)
|
|
@ -96,18 +96,13 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
token_count = 0
|
||||
last_comma = -1
|
||||
|
||||
def next_chunk(is_last=False):
|
||||
"""puts current chunk into the list of results and produces the next one - empty;
|
||||
if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count"""
|
||||
def next_chunk():
|
||||
"""puts current chunk into the list of results and produces the next one - empty"""
|
||||
nonlocal token_count
|
||||
nonlocal last_comma
|
||||
nonlocal chunk
|
||||
|
||||
if is_last:
|
||||
token_count += len(chunk.tokens)
|
||||
else:
|
||||
token_count += self.chunk_length
|
||||
|
||||
token_count += len(chunk.tokens)
|
||||
to_add = self.chunk_length - len(chunk.tokens)
|
||||
if to_add > 0:
|
||||
chunk.tokens += [self.id_end] * to_add
|
||||
|
@ -121,10 +116,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
chunk = PromptChunk()
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
if text == 'BREAK' and weight == -1:
|
||||
next_chunk()
|
||||
continue
|
||||
|
||||
position = 0
|
||||
while position < len(tokens):
|
||||
token = tokens[position]
|
||||
|
@ -168,7 +159,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
position += embedding_length_in_tokens
|
||||
|
||||
if len(chunk.tokens) > 0 or len(chunks) == 0:
|
||||
next_chunk(is_last=True)
|
||||
next_chunk()
|
||||
|
||||
return chunks, token_count
|
||||
|
||||
|
|
|
@ -96,6 +96,15 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
from modules import sd_models
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||
|
||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import time
|
||||
|
||||
def should_hijack_ip2p(checkpoint_info):
|
||||
from modules import sd_models_config
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||
|
||||
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
|
|
@ -9,7 +9,7 @@ from torch import einsum
|
|||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
from modules import shared, errors, devices
|
||||
from modules import shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
@ -52,25 +52,18 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
del q, k, v
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1 = r1.to(dtype)
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
@ -85,56 +78,49 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
dtype = q_in.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
|
||||
k_in *= self.scale
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k_in = k_in * self.scale
|
||||
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
del context, x
|
||||
|
||||
r1 = r1.to(dtype)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
@ -217,21 +203,13 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k = self.to_k(context_k)
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k = self.to_k(context_k) * self.scale
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k = k * self.scale
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
r = r.to(dtype)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
||||
|
@ -247,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
@ -256,14 +234,8 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||
|
||||
x = x.to(dtype)
|
||||
|
||||
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
|
||||
|
||||
out_proj, dropout = self.to_out
|
||||
|
@ -296,31 +268,15 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
|
||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||
return efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_chunk_size=q_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min = kv_chunk_size_min,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def get_xformers_flash_attention_op(q, k, v):
|
||||
if not shared.cmd_opts.xformers_flash_attention:
|
||||
return None
|
||||
|
||||
try:
|
||||
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
||||
fw, bw = flash_attention_op
|
||||
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
|
||||
return flash_attention_op
|
||||
except Exception as e:
|
||||
errors.display_once(e, "enabling flash attention")
|
||||
|
||||
return None
|
||||
return efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_chunk_size=q_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min = kv_chunk_size_min,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
|
@ -328,20 +284,13 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
||||
|
||||
out = out.to(dtype)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
@ -413,14 +362,10 @@ def xformers_attnblock_forward(self, x):
|
|||
v = self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
|
||||
out = out.to(dtype)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
|
|
@ -1,8 +1,4 @@
|
|||
import torch
|
||||
from packaging import version
|
||||
|
||||
from modules import devices
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
|
||||
|
||||
class TorchHijackForUnet:
|
||||
|
@ -32,37 +28,3 @@ class TorchHijackForUnet:
|
|||
|
||||
|
||||
th = TorchHijackForUnet()
|
||||
|
||||
|
||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
for y in cond.keys():
|
||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
|
||||
with devices.autocast():
|
||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||
|
||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
def forward(self, x):
|
||||
if devices.unet_needs_upcast:
|
||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
else:
|
||||
return torch.nn.GELU.forward(self, x)
|
||||
|
||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
import importlib
|
||||
|
||||
class CondFunc:
|
||||
def __new__(cls, orig_func, sub_func, cond_func):
|
||||
self = super(CondFunc, cls).__new__(cls)
|
||||
if isinstance(orig_func, str):
|
||||
func_path = orig_func.split('.')
|
||||
for i in range(len(func_path)-1, -1, -1):
|
||||
try:
|
||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
||||
break
|
||||
except ImportError:
|
||||
pass
|
||||
for attr_name in func_path[i:-1]:
|
||||
resolved_obj = getattr(resolved_obj, attr_name)
|
||||
orig_func = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
||||
self.__init__(orig_func, sub_func, cond_func)
|
||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
||||
def __init__(self, orig_func, sub_func, cond_func):
|
||||
self.__orig_func = orig_func
|
||||
self.__sub_func = sub_func
|
||||
self.__cond_func = cond_func
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
||||
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
||||
else:
|
||||
return self.__orig_func(*args, **kwargs)
|
|
@ -2,6 +2,8 @@ import collections
|
|||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import time
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
|
@ -12,13 +14,12 @@ import ldm.modules.midas as midas
|
|||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
||||
checkpoints_list = {}
|
||||
checkpoint_alisases = {}
|
||||
|
@ -40,17 +41,14 @@ class CheckpointInfo:
|
|||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
self.name = name
|
||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||
self.title = name
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
self.hash = model_hash(filename)
|
||||
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||
self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
|
@ -58,18 +56,12 @@ class CheckpointInfo:
|
|||
checkpoint_alisases[id] = self
|
||||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||
if self.sha256 is None:
|
||||
return
|
||||
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
|
||||
self.shorthash = self.sha256[0:10]
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
||||
|
||||
checkpoints_list.pop(self.title)
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
self.register()
|
||||
self.ids += [self.shorthash, self.sha256]
|
||||
self.register()
|
||||
|
||||
return self.shorthash
|
||||
|
||||
|
@ -102,6 +94,17 @@ def checkpoint_tiles():
|
|||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||
|
||||
|
||||
def find_checkpoint_config(info):
|
||||
if info is None:
|
||||
return shared.cmd_opts.config
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return shared.cmd_opts.config
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
checkpoint_alisases.clear()
|
||||
|
@ -162,7 +165,7 @@ def select_checkpoint():
|
|||
print(f" - directory {model_path}", file=sys.stderr)
|
||||
if shared.cmd_opts.ckpt_dir is not None:
|
||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||
print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||
|
@ -207,7 +210,9 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device()
|
||||
device = map_location or shared.weight_load_location
|
||||
if device is None:
|
||||
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
@ -219,72 +224,49 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||
return sd
|
||||
|
||||
|
||||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
|
||||
if cache_enabled and checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
else:
|
||||
# load from file
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
timer.record("load weights from disk")
|
||||
sd = read_state_dict(checkpoint_info.filename)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
|
||||
return res
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
timer.record("apply channels_last")
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
# with --upcast-sampling, don't convert the depth model weights to float16
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
|
||||
timer.record("apply half()")
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
timer.record("apply dtype to VAE")
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
# clean up cache if limit is reached
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_info.filename
|
||||
|
@ -297,7 +279,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||
sd_vae.load_vae(model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
|
@ -310,7 +291,7 @@ def enable_midas_autodownload():
|
|||
location automatically.
|
||||
"""
|
||||
|
||||
midas_path = os.path.join(paths.models_path, 'midas')
|
||||
midas_path = os.path.join(models_path, 'midas')
|
||||
|
||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||
# a location that differs from where other scripts using this model look.
|
||||
|
@ -343,23 +324,24 @@ def enable_midas_autodownload():
|
|||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def repair_config(sd_config):
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.start = time.time()
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
res = end - self.start
|
||||
self.start = end
|
||||
return res
|
||||
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||
|
||||
if checkpoint_config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
|
@ -367,30 +349,29 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
# Hardcoded config for now...
|
||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
sd_config.model.params.conditioning_key = "hybrid"
|
||||
sd_config.model.params.unet_config.params.in_channels = 9
|
||||
sd_config.model.params.finetune_keys = None
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if already_loaded_state_dict is not None:
|
||||
state_dict = already_loaded_state_dict
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
repair_config(sd_config)
|
||||
|
||||
timer.record("load config")
|
||||
|
||||
print(f"Creating model from config: {checkpoint_config}")
|
||||
|
||||
sd_model = None
|
||||
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
@ -399,35 +380,29 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
sd_model.used_config = checkpoint_config
|
||||
elapsed_create = timer.elapsed()
|
||||
|
||||
timer.record("create model")
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
elapsed_load_weights = timer.elapsed()
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||
else:
|
||||
sd_model.to(shared.device)
|
||||
|
||||
timer.record("move model to device")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
timer.record("load textual inversion embeddings")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
timer.record("scripts callbacks")
|
||||
elapsed_the_rest = timer.elapsed()
|
||||
|
||||
print(f"Model loaded in {timer.summary()}.")
|
||||
print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
|
||||
|
||||
return sd_model
|
||||
|
||||
|
@ -438,7 +413,6 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
if sd_model is None: # previous model load failed
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
|
@ -446,44 +420,38 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info)
|
||||
return shared.sd_model
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
timer = Timer()
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||
return shared.sd_model
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
except Exception as e:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||
load_model_weights(sd_model, current_checkpoint_info)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
timer.record("hijack")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
elapsed = timer.elapsed()
|
||||
|
||||
print(f"Weights loaded in {elapsed:.1f}s.")
|
||||
|
||||
return sd_model
|
||||
|
|
|
@ -1,112 +0,0 @@
|
|||
import re
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared, paths, sd_disable_initialization
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
|
||||
"""
|
||||
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||
"""
|
||||
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
from modules import devices
|
||||
|
||||
device = devices.cpu
|
||||
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||
use_checkpoint=True,
|
||||
use_fp16=False,
|
||||
image_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
model_channels=320,
|
||||
attention_resolutions=[4, 2, 1],
|
||||
num_res_blocks=2,
|
||||
channel_mult=[1, 2, 4, 4],
|
||||
num_head_channels=64,
|
||||
use_spatial_transformer=True,
|
||||
use_linear_in_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=1024,
|
||||
legacy=False
|
||||
)
|
||||
unet.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||
unet.load_state_dict(unet_sd, strict=True)
|
||||
unet.to(device=device, dtype=torch.float)
|
||||
|
||||
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||
|
||||
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
|
||||
|
||||
return out < -1
|
||||
|
||||
|
||||
def guess_model_config_from_state_dict(sd, filename):
|
||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
|
||||
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
|
||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sd2_inpainting
|
||||
elif is_using_v_parameterization_for_sd2(sd):
|
||||
return config_sd2v
|
||||
else:
|
||||
return config_sd2
|
||||
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_inpainting
|
||||
if diffusion_model_input.shape[1] == 8:
|
||||
return config_instruct_pix2pix
|
||||
|
||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
return config_alt_diffusion
|
||||
|
||||
return config_default
|
||||
|
||||
|
||||
def find_checkpoint_config(state_dict, info):
|
||||
if info is None:
|
||||
return guess_model_config_from_state_dict(state_dict, "")
|
||||
|
||||
config = find_checkpoint_config_near_filename(info)
|
||||
if config is not None:
|
||||
return config
|
||||
|
||||
return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
|
||||
|
||||
def find_checkpoint_config_near_filename(info):
|
||||
if info is None:
|
||||
return None
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return None
|
||||
|
|
@ -1,11 +1,53 @@
|
|||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||
from collections import namedtuple, deque
|
||||
import numpy as np
|
||||
from math import floor
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
import torchsde._brownian.brownian_interval
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser, devices, processing, images, sd_vae_approx
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_compvis.samplers_data_compvis,
|
||||
*samplers_data_k_diffusion,
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
|
@ -31,8 +73,8 @@ def create_sampler(name, model):
|
|||
def set_samplers():
|
||||
global samplers, samplers_for_img2img
|
||||
|
||||
hidden = set(shared.opts.hide_samplers)
|
||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
|
||||
hidden = set(opts.hide_samplers)
|
||||
hidden_img2img = set(opts.hide_samplers + ['PLMS'])
|
||||
|
||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
|
@ -45,3 +87,464 @@ def set_samplers():
|
|||
|
||||
|
||||
set_samplers()
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
}
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
if opts.img2img_fix_steps or steps is not None:
|
||||
requested_steps = (steps or p.steps)
|
||||
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
t_enc = requested_steps - 1
|
||||
else:
|
||||
steps = p.steps
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
|
||||
def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 0.0
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
cond = tensor
|
||||
|
||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||
# not 100% correct but should work well enough
|
||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||
last_vector = unconditional_conditioning[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||
else:
|
||||
self.last_latent = res[1]
|
||||
|
||||
store_latent(self.last_latent)
|
||||
|
||||
self.step += 1
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ddim
|
||||
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
store_latent(x_out[0:uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
# MPS fix for randn in torchsde
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
else:
|
||||
generator = torch.Generator(device).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
||||
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 1.0
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
if opts.live_preview_content == "Combined":
|
||||
store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap.step = 0
|
||||
self.eta = p.eta or opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
if discard_next_to_last_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
|
||||
return sigmas
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
|
||||
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
||||
if 'sigma_max' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
||||
if 'sigma_sched' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_sched'] = sigma_sched
|
||||
if 'sigmas' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
x = x * sigmas[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, processing, images, sd_vae_approx
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
if opts.img2img_fix_steps or steps is not None:
|
||||
requested_steps = (steps or p.steps)
|
||||
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
t_enc = requested_steps - 1
|
||||
else:
|
||||
steps = p.steps
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
|
||||
def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
|
@ -1,160 +0,0 @@
|
|||
import math
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules.shared import state
|
||||
from modules import sd_samplers_common, prompt_parser, shared
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
]
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
cond = tensor
|
||||
|
||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||
# not 100% correct but should work well enough
|
||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||
last_vector = unconditional_conditioning[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||
else:
|
||||
self.last_latent = res[1]
|
||||
|
||||
sd_samplers_common.store_latent(self.last_latent)
|
||||
|
||||
self.step += 1
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
if self.eta != 0.0:
|
||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
||||
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == math.floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
|
@ -1,331 +0,0 @@
|
|||
from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import einops
|
||||
import k_diffusion.sampling
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
}
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
if not is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
else:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = [tensor[a:b]]
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if not is_edit_model:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
else:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
if opts.live_preview_content == "Combined":
|
||||
sd_samplers_common.store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params["Eta"] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
if discard_next_to_last_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
|
||||
return sigmas
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
|
||||
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
||||
if 'sigma_max' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
||||
if 'sigma_sched' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_sched'] = sigma_sched
|
||||
if 'sigmas' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
x = x * sigmas[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
|
@ -3,12 +3,13 @@ import safetensors.torch
|
|||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
from modules import shared, devices, script_callbacks, sd_models
|
||||
from modules.paths import models_path
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||
vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
vae_dict = {}
|
||||
|
||||
|
@ -71,13 +72,6 @@ def refresh_vae_list():
|
|||
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
|
||||
]
|
||||
|
||||
if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
|
||||
paths += [
|
||||
os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
|
||||
os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
|
||||
os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
|
||||
]
|
||||
|
||||
candidates = []
|
||||
for path in paths:
|
||||
candidates += glob.iglob(path, recursive=True)
|
||||
|
@ -100,10 +94,8 @@ def resolve_vae(checkpoint_file):
|
|||
if shared.cmd_opts.vae_path is not None:
|
||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||
|
||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
|
||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"):
|
||||
return vae_near_checkpoint, 'found near the checkpoint'
|
||||
|
||||
if shared.opts.sd_vae == "None":
|
||||
|
@ -113,18 +105,12 @@ def resolve_vae(checkpoint_file):
|
|||
if vae_from_options is not None:
|
||||
return vae_from_options, 'specified in settings'
|
||||
|
||||
if not is_automatic:
|
||||
if shared.opts.sd_vae != "Automatic":
|
||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def load_vae_dict(filename, map_location):
|
||||
vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict_1
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||
global vae_dict, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
@ -142,7 +128,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
|||
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||
store_base_vae(model)
|
||||
|
||||
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
|
||||
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
|
||||
if cache_enabled:
|
||||
|
|
|
@ -9,34 +9,30 @@ from PIL import Image
|
|||
import gradio as gr
|
||||
import tqdm
|
||||
|
||||
import modules.artists
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||
from modules.paths import models_path, script_path, data_path
|
||||
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
|
||||
demo = None
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
|
@ -47,7 +43,6 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
|
|||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
|
@ -60,7 +55,6 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
|
|||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
|
@ -70,21 +64,20 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage
|
|||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
|
@ -104,10 +97,6 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o
|
|||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
||||
|
@ -127,15 +116,13 @@ restricted_opts = {
|
|||
}
|
||||
|
||||
ui_reorder_categories = [
|
||||
"inpaint",
|
||||
"sampler",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"dimensions",
|
||||
"cfg",
|
||||
"seed",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"batch",
|
||||
"override_settings",
|
||||
"scripts",
|
||||
]
|
||||
|
||||
|
@ -154,7 +141,7 @@ config_filename = cmd_opts.ui_settings_file
|
|||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = {}
|
||||
loaded_hypernetworks = []
|
||||
loaded_hypernetwork = None
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
|
@ -162,6 +149,8 @@ def reload_hypernetworks():
|
|||
global hypernetworks
|
||||
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
||||
|
||||
|
||||
|
||||
class State:
|
||||
|
@ -239,7 +228,7 @@ class State:
|
|||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
|
@ -262,6 +251,8 @@ class State:
|
|||
state = State()
|
||||
state.server_start = time.time()
|
||||
|
||||
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
||||
|
||||
styles_filename = cmd_opts.styles_file
|
||||
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
||||
|
||||
|
@ -269,6 +260,12 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||
|
||||
face_restorers = []
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
||||
self.default = default
|
||||
|
@ -327,7 +324,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||
|
||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||
|
@ -349,17 +346,17 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
|
||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
|
||||
"directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
|
@ -370,7 +367,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
|
@ -396,10 +392,12 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
||||
|
@ -407,8 +405,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
|
@ -419,12 +417,12 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
|
||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
||||
|
@ -432,18 +430,12 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
|
@ -451,23 +443,18 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
|
@ -482,12 +469,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
|
@ -608,37 +589,11 @@ class Options:
|
|||
|
||||
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
||||
|
||||
def cast_value(self, key, value):
|
||||
"""casts an arbitrary to the same type as this setting's value with key
|
||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
default_value = self.data_labels[key].default
|
||||
if default_value is None:
|
||||
default_value = getattr(self, key, None)
|
||||
if default_value is None:
|
||||
return None
|
||||
|
||||
expected_type = type(default_value)
|
||||
if expected_type == bool and value == "False":
|
||||
value = False
|
||||
else:
|
||||
value = expected_type(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
|
||||
opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
opts.load(config_filename)
|
||||
|
||||
settings_components = None
|
||||
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
|
||||
|
||||
latent_upscale_default_mode = "Latent"
|
||||
latent_upscale_modes = {
|
||||
"Latent": {"mode": "bilinear", "antialias": False},
|
||||
|
@ -699,17 +654,3 @@ mem_mon.start()
|
|||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
||||
|
||||
def html_path(filename):
|
||||
return os.path.join(script_path, "html", filename)
|
||||
|
||||
|
||||
def html(filename):
|
||||
path = html_path(filename)
|
||||
|
||||
if os.path.exists(path):
|
||||
with open(path, encoding="utf8") as file:
|
||||
return file.read()
|
||||
|
||||
return ""
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
def postprocessing_scripts():
|
||||
import modules.scripts
|
||||
|
||||
return modules.scripts.scripts_postproc.scripts
|
||||
|
||||
|
||||
def sd_vae_items():
|
||||
import modules.sd_vae
|
||||
|
||||
return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
|
||||
|
||||
|
||||
def refresh_vae_list():
|
||||
import modules.sd_vae
|
||||
|
||||
return modules.sd_vae.refresh_vae_list
|
|
@ -40,18 +40,12 @@ def apply_styles_to_prompt(prompt, styles):
|
|||
class StyleDatabase:
|
||||
def __init__(self, path: str):
|
||||
self.no_style = PromptStyle("None", "", "")
|
||||
self.styles = {}
|
||||
self.path = path
|
||||
self.styles = {"None": self.no_style}
|
||||
|
||||
self.reload()
|
||||
|
||||
def reload(self):
|
||||
self.styles.clear()
|
||||
|
||||
if not os.path.exists(self.path):
|
||||
if not os.path.exists(path):
|
||||
return
|
||||
|
||||
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
|
||||
with open(path, "r", encoding="utf-8-sig", newline='') as file:
|
||||
reader = csv.DictReader(file)
|
||||
for row in reader:
|
||||
# Support loading old CSV format with "name, text"-columns
|
||||
|
|
|
@ -67,7 +67,7 @@ def _summarize_chunk(
|
|||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
exp_weights = torch.exp(attn_weights - max_score)
|
||||
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
|
||||
|
@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||
)
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
|
||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||
return hidden_states_slice
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import datetime
|
|||
import json
|
||||
import os
|
||||
|
||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
|
||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
|
||||
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
||||
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
||||
|
|
|
@ -6,12 +6,13 @@ import sys
|
|||
import tqdm
|
||||
import time
|
||||
|
||||
from modules import paths, shared, images, deepbooru
|
||||
from modules import shared, images, deepbooru
|
||||
from modules.paths import models_path
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.textual_inversion import autocrop
|
||||
|
||||
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
@ -19,7 +20,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height,
|
|||
if process_caption_deepbooru:
|
||||
deepbooru.model.start()
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
|
||||
|
||||
finally:
|
||||
|
||||
|
@ -108,30 +109,8 @@ def split_pic(image, inverse_xy, width, height, overlap_ratio):
|
|||
splitted = image.crop((0, y, to_w, y + to_h))
|
||||
yield splitted
|
||||
|
||||
# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
|
||||
def center_crop(image: Image, w: int, h: int):
|
||||
iw, ih = image.size
|
||||
if ih / h < iw / w:
|
||||
sw = w * ih / h
|
||||
box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
|
||||
else:
|
||||
sh = h * iw / w
|
||||
box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
|
||||
return image.resize((w, h), Image.Resampling.LANCZOS, box)
|
||||
|
||||
|
||||
def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
|
||||
iw, ih = image.size
|
||||
err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
|
||||
wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
|
||||
if minarea <= w * h <= maxarea and err(w, h) <= threshold),
|
||||
key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
|
||||
default=None
|
||||
)
|
||||
return wh and center_crop(image, *wh)
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
width = process_width
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
|
@ -198,7 +177,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
|
||||
dnn_model_path = None
|
||||
try:
|
||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
|
||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
|
||||
except Exception as e:
|
||||
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
||||
|
||||
|
@ -215,14 +194,6 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
save_pic(focal, index, params, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_multicrop:
|
||||
cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
||||
if cropped is not None:
|
||||
save_pic(cropped, index, params, existing_caption=existing_caption)
|
||||
else:
|
||||
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
|
||||
process_default_resize = False
|
||||
|
||||
if process_default_resize:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index, params, existing_caption=existing_caption)
|
||||
|
|
|
@ -15,7 +15,7 @@ import numpy as np
|
|||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
@ -50,7 +50,6 @@ class Embedding:
|
|||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.optimizer_state_dict = None
|
||||
self.filename = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
|
@ -112,7 +111,6 @@ class EmbeddingDatabase:
|
|||
self.skipped_embeddings = {}
|
||||
self.expected_shape = -1
|
||||
self.embedding_dirs = {}
|
||||
self.previously_displayed_embeddings = ()
|
||||
|
||||
def add_embedding_dir(self, path):
|
||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||
|
@ -184,7 +182,6 @@ class EmbeddingDatabase:
|
|||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
embedding.filename = path
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
|
@ -195,7 +192,7 @@ class EmbeddingDatabase:
|
|||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
||||
for root, dirs, fns in os.walk(embdir.path, followlinks=True):
|
||||
for root, dirs, fns in os.walk(embdir.path):
|
||||
for fn in fns:
|
||||
try:
|
||||
fullfn = os.path.join(root, fn)
|
||||
|
@ -229,12 +226,9 @@ class EmbeddingDatabase:
|
|||
self.load_from_dir(embdir)
|
||||
embdir.update()
|
||||
|
||||
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
||||
if self.previously_displayed_embeddings != displayed_embeddings:
|
||||
self.previously_displayed_embeddings = displayed_embeddings
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
|
||||
def find_embedding_at_position(self, tokens, offset):
|
||||
token = tokens[offset]
|
||||
|
@ -458,8 +452,6 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
sd_hijack_checkpoint.add()
|
||||
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
@ -625,11 +617,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.close()
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
sd_hijack_checkpoint.remove()
|
||||
|
||||
return embedding, filename
|
||||
|
||||
|
||||
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
old_embedding_name = embedding.name
|
||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
import time
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.start = time.time()
|
||||
self.records = {}
|
||||
self.total = 0
|
||||
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
res = end - self.start
|
||||
self.start = end
|
||||
return res
|
||||
|
||||
def record(self, category, extra_time=0):
|
||||
e = self.elapsed()
|
||||
if category not in self.records:
|
||||
self.records[category] = 0
|
||||
|
||||
self.records[category] += e + extra_time
|
||||
self.total += e + extra_time
|
||||
|
||||
def summary(self):
|
||||
res = f"{self.total:.1f}s"
|
||||
|
||||
additions = [x for x in self.records.items() if x[1] >= 0.1]
|
||||
if not additions:
|
||||
return res
|
||||
|
||||
res += " ("
|
||||
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
|
||||
res += ")"
|
||||
|
||||
return res
|
|
@ -1,6 +1,5 @@
|
|||
import modules.scripts
|
||||
from modules import sd_samplers
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||
StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
@ -9,15 +8,13 @@ import modules.processing as processing
|
|||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
|
||||
prompt=prompt,
|
||||
styles=prompt_styles,
|
||||
styles=[prompt_style, prompt_style2],
|
||||
negative_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
subseed=subseed,
|
||||
|
@ -41,7 +38,6 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||
hr_second_pass_steps=hr_second_pass_steps,
|
||||
hr_resize_x=hr_resize_x,
|
||||
hr_resize_y=hr_resize_y,
|
||||
override_settings=override_settings,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
|
|
747
modules/ui.py
747
modules/ui.py
File diff suppressed because it is too large
Load Diff
|
@ -1,206 +0,0 @@
|
|||
import json
|
||||
import html
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
import subprocess as sp
|
||||
|
||||
from modules import call_queue, shared
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
import modules.images
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
|
||||
|
||||
def update_generation_info(generation_info, html_info, img_index):
|
||||
try:
|
||||
generation_info = json.loads(generation_info)
|
||||
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
|
||||
return html_info, gr.update()
|
||||
return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
|
||||
except Exception:
|
||||
pass
|
||||
# if the json parse or anything else fails, just return the old html_info
|
||||
return html_info, gr.update()
|
||||
|
||||
|
||||
def plaintext_to_html(text):
|
||||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||
return text
|
||||
|
||||
|
||||
def save_files(js_data, images, do_make_zip, index):
|
||||
import csv
|
||||
filenames = []
|
||||
fullfns = []
|
||||
|
||||
#quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
||||
class MyObject:
|
||||
def __init__(self, d=None):
|
||||
if d is not None:
|
||||
for key, value in d.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
data = json.loads(js_data)
|
||||
|
||||
p = MyObject(data)
|
||||
path = shared.opts.outdir_save
|
||||
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
|
||||
extension: str = shared.opts.samples_format
|
||||
start_index = 0
|
||||
|
||||
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
||||
|
||||
images = [images[index]]
|
||||
start_index = index
|
||||
|
||||
os.makedirs(shared.opts.outdir_save, exist_ok=True)
|
||||
|
||||
with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
||||
at_start = file.tell() == 0
|
||||
writer = csv.writer(file)
|
||||
if at_start:
|
||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||
|
||||
for image_index, filedata in enumerate(images, start_index):
|
||||
image = image_from_url_text(filedata)
|
||||
|
||||
is_grid = image_index < p.index_of_first_image
|
||||
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
||||
|
||||
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
||||
|
||||
filename = os.path.relpath(fullfn, path)
|
||||
filenames.append(filename)
|
||||
fullfns.append(fullfn)
|
||||
if txt_fullfn:
|
||||
filenames.append(os.path.basename(txt_fullfn))
|
||||
fullfns.append(txt_fullfn)
|
||||
|
||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||
|
||||
# Make Zip
|
||||
if do_make_zip:
|
||||
zip_filepath = os.path.join(path, "images.zip")
|
||||
|
||||
from zipfile import ZipFile
|
||||
with ZipFile(zip_filepath, "w") as zip_file:
|
||||
for i in range(len(fullfns)):
|
||||
with open(fullfns[i], mode="rb") as f:
|
||||
zip_file.writestr(filenames[i], f.read())
|
||||
fullfns.insert(0, zip_filepath)
|
||||
|
||||
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def create_output_panel(tabname, outdir):
|
||||
from modules import shared
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
|
||||
def open_folder(f):
|
||||
if not os.path.exists(f):
|
||||
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||
return
|
||||
elif not os.path.isdir(f):
|
||||
print(f"""
|
||||
WARNING
|
||||
An open_folder request was made with an argument that is not a folder.
|
||||
This could be an error or a malicious attempt to run code on your computer.
|
||||
Requested path was: {f}
|
||||
""", file=sys.stderr)
|
||||
return
|
||||
|
||||
if not shared.cmd_opts.hide_ui_dir_config:
|
||||
path = os.path.normpath(f)
|
||||
if platform.system() == "Windows":
|
||||
os.startfile(path)
|
||||
elif platform.system() == "Darwin":
|
||||
sp.Popen(["open", path])
|
||||
elif "microsoft-standard-WSL2" in platform.uname().release:
|
||||
sp.Popen(["wsl-open", path])
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}"):
|
||||
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
|
||||
|
||||
if tabname != "extras":
|
||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
|
||||
|
||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
||||
|
||||
open_folder_button.click(
|
||||
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
if tabname != "extras":
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||
if tabname == 'txt2img' or tabname == 'img2img':
|
||||
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
|
||||
generation_info_button.click(
|
||||
fn=update_generation_info,
|
||||
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
||||
inputs=[generation_info, html_info, html_info],
|
||||
outputs=[html_info, html_info],
|
||||
)
|
||||
|
||||
save.click(
|
||||
fn=call_queue.wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
result_gallery,
|
||||
html_info,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
save_zip.click(
|
||||
fn=call_queue.wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
result_gallery,
|
||||
html_info,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_log,
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
for paste_tabname, paste_button in buttons.items():
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
|
||||
))
|
||||
|
||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
|
@ -11,16 +11,6 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
|||
return "button"
|
||||
|
||||
|
||||
class ToolButtonTop(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool-top", **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class FormRow(gr.Row, gr.components.FormComponent):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
|
@ -47,12 +37,3 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
|||
|
||||
def get_block_name(self):
|
||||
return "colorpicker"
|
||||
|
||||
|
||||
class DropdownMulti(gr.Dropdown):
|
||||
"""Same as gr.Dropdown but always multiselect"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(multiselect=True, **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "dropdown"
|
||||
|
|
|
@ -13,7 +13,7 @@ import shutil
|
|||
import errno
|
||||
|
||||
from modules import extensions, shared, paths
|
||||
from modules.call_queue import wrap_gradio_gpu_call
|
||||
|
||||
|
||||
available_extensions = {"extensions": []}
|
||||
|
||||
|
@ -50,17 +50,12 @@ def apply_and_restart(disable_list, update_list):
|
|||
shared.state.need_restart = True
|
||||
|
||||
|
||||
def check_updates(id_task, disable_list):
|
||||
def check_updates():
|
||||
check_access()
|
||||
|
||||
disabled = json.loads(disable_list)
|
||||
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
||||
|
||||
exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
|
||||
shared.state.job_count = len(exts)
|
||||
|
||||
for ext in exts:
|
||||
shared.state.textinfo = ext.name
|
||||
for ext in extensions.extensions:
|
||||
if ext.remote is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
ext.check_updates()
|
||||
|
@ -68,9 +63,7 @@ def check_updates(id_task, disable_list):
|
|||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
return extension_table(), ""
|
||||
return extension_table()
|
||||
|
||||
|
||||
def extension_table():
|
||||
|
@ -139,7 +132,7 @@ def install_extension_from_url(dirname, url):
|
|||
normalized_url = normalize_git_url(url)
|
||||
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
|
||||
|
||||
tmpdir = os.path.join(paths.data_path, "tmp", dirname)
|
||||
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
@ -280,13 +273,12 @@ def create_ui():
|
|||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||
with gr.TabItem("Installed"):
|
||||
|
||||
with gr.Row(elem_id="extensions_installed_top"):
|
||||
with gr.Row():
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
|
||||
|
||||
info = gr.HTML()
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
|
||||
apply.click(
|
||||
|
@ -297,10 +289,10 @@ def create_ui():
|
|||
)
|
||||
|
||||
check.click(
|
||||
fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
|
||||
fn=check_updates,
|
||||
_js="extensions_check",
|
||||
inputs=[info, extensions_disabled_list],
|
||||
outputs=[extensions_table, info],
|
||||
inputs=[],
|
||||
outputs=[extensions_table],
|
||||
)
|
||||
|
||||
with gr.TabItem("Available"):
|
||||
|
|
|
@ -1,245 +0,0 @@
|
|||
import glob
|
||||
import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
from modules import shared
|
||||
import gradio as gr
|
||||
import json
|
||||
import html
|
||||
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
|
||||
extra_pages = []
|
||||
allowed_dirs = set()
|
||||
|
||||
|
||||
def register_page(page):
|
||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||
|
||||
extra_pages.append(page)
|
||||
allowed_dirs.clear()
|
||||
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
||||
|
||||
|
||||
def add_pages_to_demo(app):
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
def __init__(self, title):
|
||||
self.title = title
|
||||
self.name = title.lower()
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
self.allow_negative_prompt = False
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def link_preview(self, filename):
|
||||
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
|
||||
|
||||
def search_terms_from_path(self, filename, possible_directories=None):
|
||||
abspath = os.path.abspath(filename)
|
||||
|
||||
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
||||
parentdir = os.path.abspath(parentdir)
|
||||
if abspath.startswith(parentdir):
|
||||
return abspath[len(parentdir):].replace('\\', '/')
|
||||
|
||||
return ""
|
||||
|
||||
def create_html(self, tabname):
|
||||
view = shared.opts.extra_networks_default_view
|
||||
items_html = ''
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
|
||||
if not os.path.isdir(x):
|
||||
continue
|
||||
|
||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
||||
while subdir.startswith("/"):
|
||||
subdir = subdir[1:]
|
||||
|
||||
subdirs[subdir] = 1
|
||||
|
||||
if subdirs:
|
||||
subdirs = {"": 1, **subdirs}
|
||||
|
||||
subdirs_html = "".join([f"""
|
||||
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
{html.escape(subdir if subdir!="" else "all")}
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
|
||||
for item in self.list_items():
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
if items_html == '':
|
||||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
res = f"""
|
||||
<div id='{tabname}_{self.name}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
|
||||
{subdirs_html}
|
||||
</div>
|
||||
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
|
||||
{items_html}
|
||||
</div>
|
||||
"""
|
||||
|
||||
return res
|
||||
|
||||
def list_items(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return []
|
||||
|
||||
def create_html_for_item(self, item, tabname):
|
||||
preview = item.get("preview", None)
|
||||
|
||||
onclick = item.get("onclick", None)
|
||||
if onclick is None:
|
||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
|
||||
args = {
|
||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
"name": item["name"],
|
||||
"card_clicked": onclick,
|
||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||
"search_term": item.get("search_term", ""),
|
||||
}
|
||||
|
||||
return self.card_page.format(**args)
|
||||
|
||||
|
||||
def intialize():
|
||||
extra_pages.clear()
|
||||
|
||||
|
||||
class ExtraNetworksUi:
|
||||
def __init__(self):
|
||||
self.pages = None
|
||||
self.stored_extra_pages = None
|
||||
|
||||
self.button_save_preview = None
|
||||
self.preview_target_filename = None
|
||||
|
||||
self.tabname = None
|
||||
|
||||
|
||||
def pages_in_preferred_order(pages):
|
||||
tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
|
||||
|
||||
def tab_name_score(name):
|
||||
name = name.lower()
|
||||
for i, possible_match in enumerate(tab_order):
|
||||
if possible_match in name:
|
||||
return i
|
||||
|
||||
return len(pages)
|
||||
|
||||
tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
|
||||
|
||||
return sorted(pages, key=lambda x: tab_scores[x.name])
|
||||
|
||||
|
||||
def create_ui(container, button, tabname):
|
||||
ui = ExtraNetworksUi()
|
||||
ui.pages = []
|
||||
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
||||
ui.tabname = tabname
|
||||
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title):
|
||||
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
||||
button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
|
||||
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||
|
||||
def toggle_visibility(is_visible):
|
||||
is_visible = not is_visible
|
||||
return is_visible, gr.update(visible=is_visible)
|
||||
|
||||
state_visible = gr.State(value=False)
|
||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
||||
button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
||||
|
||||
def refresh():
|
||||
res = []
|
||||
|
||||
for pg in ui.stored_extra_pages:
|
||||
pg.refresh()
|
||||
res.append(pg.create_html(ui.tabname))
|
||||
|
||||
return res
|
||||
|
||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||
|
||||
return ui
|
||||
|
||||
|
||||
def path_is_parent(parent_path, child_path):
|
||||
parent_path = os.path.abspath(parent_path)
|
||||
child_path = os.path.abspath(child_path)
|
||||
|
||||
return child_path.startswith(parent_path)
|
||||
|
||||
|
||||
def setup_ui(ui, gallery):
|
||||
def save_preview(index, images, filename):
|
||||
if len(images) == 0:
|
||||
print("There is no image in gallery to save as a preview.")
|
||||
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||
|
||||
index = int(index)
|
||||
index = 0 if index < 0 else index
|
||||
index = len(images) - 1 if index >= len(images) else index
|
||||
|
||||
img_info = images[index if index >= 0 else 0]
|
||||
image = image_from_url_text(img_info)
|
||||
|
||||
is_allowed = False
|
||||
for extra_page in ui.stored_extra_pages:
|
||||
if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
|
||||
is_allowed = True
|
||||
break
|
||||
|
||||
assert is_allowed, f'writing to {filename} is not allowed'
|
||||
|
||||
image.save(filename)
|
||||
|
||||
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||
|
||||
ui.button_save_preview.click(
|
||||
fn=save_preview,
|
||||
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
||||
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
||||
outputs=[*ui.pages]
|
||||
)
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
import html
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
|
||||
from modules import shared, ui_extra_networks, sd_models
|
||||
|
||||
|
||||
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Checkpoints')
|
||||
|
||||
def refresh(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def list_items(self):
|
||||
checkpoint: sd_models.CheckpointInfo
|
||||
for name, checkpoint in sd_models.checkpoints_list.items():
|
||||
path, ext = os.path.splitext(checkpoint.filename)
|
||||
previews = [path + ".png", path + ".preview.png"]
|
||||
|
||||
preview = None
|
||||
for file in previews:
|
||||
if os.path.isfile(file):
|
||||
preview = self.link_preview(file)
|
||||
break
|
||||
|
||||
yield {
|
||||
"name": checkpoint.name_for_extra,
|
||||
"filename": path,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
|
||||
"local_preview": path + ".png",
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
from modules import shared, ui_extra_networks
|
||||
|
||||
|
||||
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Hypernetworks')
|
||||
|
||||
def refresh(self):
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
def list_items(self):
|
||||
for name, path in shared.hypernetworks.items():
|
||||
path, ext = os.path.splitext(path)
|
||||
previews = [path + ".png", path + ".preview.png"]
|
||||
|
||||
preview = None
|
||||
for file in previews:
|
||||
if os.path.isfile(file):
|
||||
preview = self.link_preview(file)
|
||||
break
|
||||
|
||||
yield {
|
||||
"name": name,
|
||||
"filename": path,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(path),
|
||||
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||
"local_preview": path + ".png",
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [shared.cmd_opts.hypernetwork_dir]
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
from modules import ui_extra_networks, sd_hijack
|
||||
|
||||
|
||||
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Textual Inversion')
|
||||
self.allow_negative_prompt = True
|
||||
|
||||
def refresh(self):
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def list_items(self):
|
||||
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
|
||||
path, ext = os.path.splitext(embedding.filename)
|
||||
preview_file = path + ".preview.png"
|
||||
|
||||
preview = None
|
||||
if os.path.isfile(preview_file):
|
||||
preview = self.link_preview(preview_file)
|
||||
|
||||
yield {
|
||||
"name": embedding.name,
|
||||
"filename": embedding.filename,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(embedding.filename),
|
||||
"prompt": json.dumps(embedding.name),
|
||||
"local_preview": path + ".preview.png",
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
|
@ -1,57 +0,0 @@
|
|||
import gradio as gr
|
||||
from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
|
||||
|
||||
def create_ui():
|
||||
tab_index = gr.State(value=0)
|
||||
|
||||
with gr.Row().style(equal_height=False, variant='compact'):
|
||||
with gr.Column(variant='compact'):
|
||||
with gr.Tabs(elem_id="mode_extras"):
|
||||
with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
|
||||
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
||||
|
||||
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
|
||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
|
||||
|
||||
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
||||
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
||||
|
||||
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
||||
|
||||
script_inputs = scripts.scripts_postproc.setup_ui()
|
||||
|
||||
with gr.Column():
|
||||
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
|
||||
|
||||
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
|
||||
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
|
||||
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
||||
|
||||
submit.click(
|
||||
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
||||
inputs=[
|
||||
tab_index,
|
||||
extras_image,
|
||||
image_batch,
|
||||
extras_batch_input_dir,
|
||||
extras_batch_output_dir,
|
||||
show_extras_results,
|
||||
*script_inputs
|
||||
],
|
||||
outputs=[
|
||||
result_images,
|
||||
html_info_x,
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
|
||||
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
||||
|
||||
extras_image.change(
|
||||
fn=scripts.scripts_postproc.image_changed,
|
||||
inputs=[], outputs=[]
|
||||
)
|
|
@ -11,6 +11,7 @@ from modules import modelloader, shared
|
|||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class Upscaler:
|
||||
|
@ -38,7 +39,7 @@ class Upscaler:
|
|||
self.mod_scale = None
|
||||
|
||||
if self.model_path is None and self.name:
|
||||
self.model_path = os.path.join(shared.models_path, self.name)
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
if self.model_path and create_dirs:
|
||||
os.makedirs(self.model_path, exist_ok=True)
|
||||
|
||||
|
@ -94,7 +95,6 @@ class UpscalerData:
|
|||
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
||||
self.name = name
|
||||
self.data_path = path
|
||||
self.local_data_path = path
|
||||
self.scaler = upscaler
|
||||
self.scale = scale
|
||||
self.model = model
|
||||
|
@ -142,4 +142,4 @@ class UpscalerNearest(Upscaler):
|
|||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.name = "Nearest"
|
||||
self.scalers = [UpscalerData("Nearest", None, self)]
|
||||
self.scalers = [UpscalerData("Nearest", None, self)]
|
|
@ -1,10 +1,11 @@
|
|||
blendmodes
|
||||
accelerate
|
||||
basicsr
|
||||
fairscale==0.4.4
|
||||
fonts
|
||||
font-roboto
|
||||
gfpgan
|
||||
gradio==3.16.2
|
||||
gradio==3.15.0
|
||||
invisible-watermark
|
||||
numpy
|
||||
omegaconf
|
||||
|
@ -16,7 +17,7 @@ pytorch_lightning==1.7.7
|
|||
realesrgan
|
||||
scikit-image>=0.19
|
||||
timm==0.4.12
|
||||
transformers==4.25.1
|
||||
transformers==4.19.2
|
||||
torch
|
||||
einops
|
||||
jsonmerge
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
blendmodes==2022
|
||||
transformers==4.25.1
|
||||
transformers==4.19.2
|
||||
accelerate==0.12.0
|
||||
basicsr==1.4.2
|
||||
gfpgan==1.3.8
|
||||
gradio==3.16.2
|
||||
gradio==3.15.0
|
||||
numpy==1.23.3
|
||||
Pillow==9.4.0
|
||||
realesrgan==0.3.0
|
||||
|
@ -14,6 +14,7 @@ scikit-image==0.19.2
|
|||
fonts
|
||||
font-roboto
|
||||
timm==0.6.7
|
||||
fairscale==0.4.9
|
||||
piexif==1.1.3
|
||||
einops==0.4.1
|
||||
jsonmerge==1.8.0
|
||||
|
|
13
script.js
13
script.js
|
@ -13,7 +13,6 @@ function get_uiCurrentTabContent() {
|
|||
}
|
||||
|
||||
uiUpdateCallbacks = []
|
||||
uiLoadedCallbacks = []
|
||||
uiTabChangeCallbacks = []
|
||||
optionsChangedCallbacks = []
|
||||
let uiCurrentTab = null
|
||||
|
@ -21,9 +20,6 @@ let uiCurrentTab = null
|
|||
function onUiUpdate(callback){
|
||||
uiUpdateCallbacks.push(callback)
|
||||
}
|
||||
function onUiLoaded(callback){
|
||||
uiLoadedCallbacks.push(callback)
|
||||
}
|
||||
function onUiTabChange(callback){
|
||||
uiTabChangeCallbacks.push(callback)
|
||||
}
|
||||
|
@ -42,15 +38,8 @@ function executeCallbacks(queue, m) {
|
|||
queue.forEach(function(x){runCallback(x, m)})
|
||||
}
|
||||
|
||||
var executedOnLoaded = false;
|
||||
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
|
||||
executedOnLoaded = true;
|
||||
executeCallbacks(uiLoadedCallbacks);
|
||||
}
|
||||
|
||||
executeCallbacks(uiUpdateCallbacks, m);
|
||||
const newTab = get_uiCurrentTab();
|
||||
if ( newTab && ( newTab !== uiCurrentTab ) ) {
|
||||
|
@ -64,7 +53,7 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||
/**
|
||||
* Add a ctrl+enter as a shortcut to start a generation
|
||||
*/
|
||||
document.addEventListener('keydown', function(e) {
|
||||
document.addEventListener('keydown', function(e) {
|
||||
var handled = false;
|
||||
if (e.key !== undefined) {
|
||||
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||
|
|
|
@ -6,7 +6,7 @@ from tqdm import trange
|
|||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
|
||||
from modules import processing, shared, sd_samplers, prompt_parser
|
||||
from modules.processing import Processed
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
|
||||
|
@ -50,7 +50,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
|||
|
||||
x = x + d * dt
|
||||
|
||||
sd_samplers_common.store_latent(x)
|
||||
sd_samplers.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||
|
@ -104,7 +104,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
|||
dt = sigmas[i] - sigmas[i - 1]
|
||||
x = x + d * dt
|
||||
|
||||
sd_samplers_common.store_latent(x)
|
||||
sd_samplers.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from modules import scripts_postprocessing, codeformer_model
|
||||
import gradio as gr
|
||||
|
||||
from modules.ui_components import FormRow
|
||||
|
||||
|
||||
class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
|
||||
name = "CodeFormer"
|
||||
order = 3000
|
||||
|
||||
def ui(self):
|
||||
with FormRow():
|
||||
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
|
||||
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
|
||||
|
||||
return {
|
||||
"codeformer_visibility": codeformer_visibility,
|
||||
"codeformer_weight": codeformer_weight,
|
||||
}
|
||||
|
||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
|
||||
if codeformer_visibility == 0:
|
||||
return
|
||||
|
||||
restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if codeformer_visibility < 1.0:
|
||||
res = Image.blend(pp.image, res, codeformer_visibility)
|
||||
|
||||
pp.image = res
|
||||
pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
|
||||
pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user