Compare commits
180 Commits
v1.0.0-pre
...
master
Author | SHA1 | Date | |
---|---|---|---|
db4cac5d1f | |||
|
ea9bd9fc74 | ||
|
0ca1a64cfc | ||
|
3993aa43e9 | ||
|
27a50d4b38 | ||
|
475095f50a | ||
|
668d7e9b9a | ||
|
5a1b62e9f8 | ||
|
88a46e8427 | ||
|
6524478850 | ||
|
3e0f9a7543 | ||
|
40e51fd6ef | ||
|
21593c8082 | ||
|
c0e0b5844d | ||
|
dca632ab90 | ||
|
81823407d9 | ||
|
30228c67ca | ||
|
c4b9ed1a27 | ||
|
72dd5785d9 | ||
|
4306659c4d | ||
|
127bfb6c41 | ||
|
ba6a4e7e94 | ||
|
c27c0de0f7 | ||
|
6c6c6636bb | ||
|
982295aee5 | ||
|
3b2ad20ac1 | ||
|
cf0cfefe91 | ||
|
269833067d | ||
|
fb97acef63 | ||
![]() |
92bae77b88 | ||
|
1b8af15f13 | ||
|
226d840e84 | ||
|
07edf57409 | ||
|
fa4fe45403 | ||
|
814600f298 | ||
|
30a64504b1 | ||
|
b1873dbb77 | ||
|
2217331cd1 | ||
|
7738c057ce | ||
|
0426b34789 | ||
|
bfe7e7f15f | ||
|
2c1bb46c7a | ||
|
19de2a626b | ||
|
ee9fdf7f62 | ||
|
aa4688eb83 | ||
|
ab059b6e48 | ||
|
040ec7a80e | ||
|
4df63d2d19 | ||
|
274474105a | ||
|
95916e3777 | ||
|
2db8ed32cd | ||
|
f4d0538bf2 | ||
|
aa54a9d416 | ||
|
f8fcad502e | ||
|
58ae93b954 | ||
|
6e78f6a896 | ||
|
5feae71dd2 | ||
|
449531a6c5 | ||
|
9b8ed7f8ec | ||
|
9118b08606 | ||
|
0c7c36a6c6 | ||
|
cbd6329488 | ||
|
c81b52ffbd | ||
|
847ceae1f7 | ||
|
399720dac2 | ||
|
f91068f426 | ||
|
938578e8a9 | ||
|
1e2b10d2dc | ||
|
5997457fd4 | ||
|
edabd92729 | ||
|
c46f3ad98b | ||
|
7c53f81caf | ||
|
00dab8f10d | ||
|
aa6e55e001 | ||
|
920fe8057c | ||
|
8d7382ab24 | ||
|
e8efd2ec47 | ||
|
659d602dce | ||
|
f6b7768f84 | ||
|
1d24665229 | ||
|
09a142a05a | ||
|
fb58fa6240 | ||
|
0a8515085e | ||
|
1d8e06d542 | ||
|
91c8d0dcfc | ||
|
fecb990deb | ||
|
41e76d1209 | ||
|
29d2d6a094 | ||
|
e2c71a4bd4 | ||
|
1e22f48f4d | ||
|
f4eeff659e | ||
|
591b68e56c | ||
|
cd7e8fb42b | ||
|
b7d2af8c7f | ||
|
1421e95960 | ||
|
5d14f282c2 | ||
|
f8feeaaedb | ||
|
d04e3e921e | ||
|
4aa7f5b5b9 | ||
|
f9edd578e9 | ||
|
02b8b957d7 | ||
|
ada17dbd7c | ||
|
e8a41df49f | ||
|
bea31e849a | ||
|
60061eb8d4 | ||
|
bd52a6d899 | ||
|
3752aad23d | ||
|
7d1f2a3a49 | ||
|
28c4c9b907 | ||
|
ce72af87d3 | ||
|
0834d4ce37 | ||
|
c99d705e57 | ||
|
38d83665d9 | ||
|
4c52dfe4ac | ||
|
41975c375c | ||
|
8ce0ccf336 | ||
|
2aac1d9778 | ||
|
6b82efd737 | ||
|
cc8c9b7474 | ||
|
32d389ef0f | ||
|
a6a5bfb155 | ||
|
eafaf14167 | ||
|
23a9d5e273 | ||
|
6b3981c068 | ||
|
14c0884fd0 | ||
|
5eee2ac398 | ||
|
56c83e453a | ||
|
9ecf1e827c | ||
|
63391419c1 | ||
|
9beb794e0b | ||
|
6f31d2210c | ||
|
d2ac95fa7b | ||
|
a43fafb481 | ||
|
7a14c8ab45 | ||
|
cdc2fa209a | ||
|
c4b9b07db6 | ||
|
645f4e7ef8 | ||
|
9e72dc7434 | ||
|
f90798c6b6 | ||
|
f4ec411f2c | ||
|
1619233a74 | ||
|
10421f93c3 | ||
|
4d634dc592 | ||
|
e57b5f7c55 | ||
|
d82d471bf7 | ||
|
6cff440182 | ||
|
d1d6ce2983 | ||
|
3cead6983e | ||
|
a85e22a127 | ||
|
e0df864b8c | ||
|
f5d73b6a66 | ||
|
0cc5f380d5 | ||
|
2de99d62dd | ||
|
dc0f05c57c | ||
|
57096823fa | ||
|
15e89ef0f6 | ||
|
2d92d05ca2 | ||
|
e425b9812b | ||
|
789d47f832 | ||
|
e179b6098a | ||
|
635499e832 | ||
|
1574e96729 | ||
|
1982ef6890 | ||
|
57c1baa774 | ||
|
23dafe6d86 | ||
|
11485659dc | ||
|
bd9b55ee90 | ||
|
ee0a0da324 | ||
|
d5ce044bcd | ||
|
1bfec873fa | ||
|
e3b53fd295 | ||
|
84d9ce30cb | ||
|
48a15821de | ||
|
bef1931895 | ||
|
ec8774729e | ||
|
e46bfa5a9e | ||
|
9fc354e130 | ||
|
d30ac02f28 | ||
|
f64af77adc | ||
|
82a28bfe35 |
29
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
29
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
|
@ -37,20 +37,20 @@ body:
|
|||
id: what-should
|
||||
attributes:
|
||||
label: What should have happened?
|
||||
description: tell what you think the normal behavior should be
|
||||
description: Tell what you think the normal behavior should be
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
id: commit
|
||||
attributes:
|
||||
label: Commit where the problem happens
|
||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: platforms
|
||||
attributes:
|
||||
label: What platforms do you use to access UI ?
|
||||
label: What platforms do you use to access the UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Windows
|
||||
|
@ -74,10 +74,27 @@ body:
|
|||
id: cmdargs
|
||||
attributes:
|
||||
label: Command Line Arguments
|
||||
description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
|
||||
description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
|
||||
render: Shell
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: extensions
|
||||
attributes:
|
||||
label: List of extensions
|
||||
description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Console logs
|
||||
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
|
||||
render: Shell
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information, context and logs
|
||||
description: Please provide us with any relevant additional info, context or log output.
|
||||
label: Additional information
|
||||
description: Please provide us with any relevant additional info or context.
|
||||
|
|
|
@ -17,7 +17,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||
- a man in a (tuxedo:1.21) - alternative syntax
|
||||
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
|
||||
- Loopback, run img2img processing multiple times
|
||||
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
||||
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
|
||||
- Textual Inversion
|
||||
- have as many embeddings as you want and use any names you like for them
|
||||
- use multiple embeddings with different numbers of vectors per token
|
||||
|
@ -155,6 +155,8 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
|
||||
- xformers - https://github.com/facebookresearch/xformers
|
||||
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
|
||||
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
|
||||
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
||||
- Security advice - RyotaK
|
||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||
- (You)
|
||||
|
|
98
configs/instruct-pix2pix.yaml
Normal file
98
configs/instruct-pix2pix.yaml
Normal file
|
@ -0,0 +1,98 @@
|
|||
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
||||
# See more details in LICENSE.
|
||||
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: edited
|
||||
cond_stage_key: edit
|
||||
# image_size: 64
|
||||
# image_size: 32
|
||||
image_size: 16
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: false
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 0 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 128
|
||||
num_workers: 1
|
||||
wrap: false
|
||||
validation:
|
||||
target: edit_dataset.EditDataset
|
||||
params:
|
||||
path: data/clip-filtered-dataset
|
||||
cache_dir: data/
|
||||
cache_name: data_10k
|
||||
split: val
|
||||
min_text_sim: 0.2
|
||||
min_image_sim: 0.75
|
||||
min_direction_sim: 0.2
|
||||
max_samples_per_prompt: 1
|
||||
min_resize_res: 512
|
||||
max_resize_res: 512
|
||||
crop_res: 512
|
||||
output_as_edit: False
|
||||
real_input: True
|
|
@ -1,8 +1,7 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-4
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
base_learning_rate: 7.5e-05
|
||||
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||
params:
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
|
@ -12,29 +11,36 @@ model:
|
|||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: crossattn
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid # important
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False # we set this to false because this is an inference only config
|
||||
finetune_keys: null
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
|
@ -43,7 +49,6 @@ model:
|
|||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
|
@ -62,7 +67,4 @@ model:
|
|||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
|
@ -1,4 +1,4 @@
|
|||
from modules import extra_networks
|
||||
from modules import extra_networks, shared
|
||||
import lora
|
||||
|
||||
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
|
@ -6,6 +6,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||
super().__init__('lora')
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_lora
|
||||
|
||||
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
multipliers = []
|
||||
for params in params_list:
|
||||
|
|
|
@ -166,7 +166,10 @@ def lora_forward(module, input, res):
|
|||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is not None:
|
||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
|
||||
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
else:
|
||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
|
||||
return res
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import torch
|
||||
import gradio as gr
|
||||
|
||||
import lora
|
||||
import extra_networks_lora
|
||||
import ui_extra_networks_lora
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||
|
||||
|
||||
def unload():
|
||||
|
@ -28,3 +29,10 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
|||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||
script_callbacks.on_script_unloaded(unload)
|
||||
script_callbacks.on_before_ui(before_ui)
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
||||
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
|
||||
|
||||
}))
|
||||
|
|
|
@ -20,13 +20,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||
preview = None
|
||||
for file in previews:
|
||||
if os.path.isfile(file):
|
||||
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
||||
preview = self.link_preview(file)
|
||||
break
|
||||
|
||||
yield {
|
||||
"name": name,
|
||||
"filename": path,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||
"local_preview": path + ".png",
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
<ul>
|
||||
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
||||
</ul>
|
||||
<span style="display:none" class='search_term'>{search_term}</span>
|
||||
</div>
|
||||
<span class='name'>{name}</span>
|
||||
</div>
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
|
||||
function extensions_apply(_, _){
|
||||
disable = []
|
||||
update = []
|
||||
var disable = []
|
||||
var update = []
|
||||
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
|
@ -16,11 +17,24 @@ function extensions_apply(_, _){
|
|||
}
|
||||
|
||||
function extensions_check(){
|
||||
var disable = []
|
||||
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
})
|
||||
|
||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||
x.innerHTML = "Loading..."
|
||||
})
|
||||
|
||||
return []
|
||||
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
|
||||
|
||||
})
|
||||
|
||||
return [id, JSON.stringify(disable)]
|
||||
}
|
||||
|
||||
function install_extension_from_index(button, url){
|
||||
|
|
|
@ -16,7 +16,7 @@ function setupExtraNetworksForTab(tabname){
|
|||
searchTerm = search.value.toLowerCase()
|
||||
|
||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
||||
text = elem.querySelector('.name').textContent.toLowerCase()
|
||||
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
||||
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
||||
})
|
||||
});
|
||||
|
@ -48,10 +48,39 @@ function setupExtraNetworks(){
|
|||
|
||||
onUiLoaded(setupExtraNetworks)
|
||||
|
||||
var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
|
||||
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
|
||||
|
||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
||||
var m = text.match(re_extranet)
|
||||
if(! m) return false
|
||||
|
||||
var partToSearch = m[1]
|
||||
var replaced = false
|
||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
|
||||
m = found.match(re_extranet);
|
||||
if(m[1] == partToSearch){
|
||||
replaced = true;
|
||||
return ""
|
||||
}
|
||||
return found;
|
||||
})
|
||||
|
||||
if(replaced){
|
||||
textarea.value = newTextareaText
|
||||
return true;
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
function cardClicked(tabname, textToAdd, allowNegativePrompt){
|
||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
||||
|
||||
textarea.value = textarea.value + " " + textToAdd
|
||||
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
|
||||
textarea.value = textarea.value + " " + textToAdd
|
||||
}
|
||||
|
||||
updateInput(textarea)
|
||||
}
|
||||
|
||||
|
@ -67,3 +96,12 @@ function saveCardPreview(event, tabname, filename){
|
|||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
}
|
||||
|
||||
function extraNetworksSearchButton(tabs_id, event){
|
||||
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
||||
button = event.target
|
||||
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
||||
|
||||
searchTextarea.value = text
|
||||
updateInput(searchTextarea)
|
||||
}
|
|
@ -17,7 +17,7 @@ titles = {
|
|||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
"\U0001F5D1": "Clear prompt",
|
||||
"\u{1f5d1}": "Clear prompt",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
"\u{1f4d2}": "Paste available values into the field",
|
||||
"\u{1f3b4}": "Show extra networks",
|
||||
|
@ -50,7 +50,7 @@ titles = {
|
|||
|
||||
"None": "Do not do anything special",
|
||||
"Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
|
||||
"X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
|
||||
"X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
|
||||
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
|
||||
|
||||
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
|
||||
|
@ -66,8 +66,8 @@ titles = {
|
|||
|
||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||
|
||||
"Loopback": "Process an image, use it as an input, repeat.",
|
||||
|
|
|
@ -191,6 +191,28 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
|||
return [prompt, negative_prompt]
|
||||
}
|
||||
|
||||
|
||||
promptTokecountUpdateFuncs = {}
|
||||
|
||||
function recalculatePromptTokens(name){
|
||||
if(promptTokecountUpdateFuncs[name]){
|
||||
promptTokecountUpdateFuncs[name]()
|
||||
}
|
||||
}
|
||||
|
||||
function recalculate_prompts_txt2img(){
|
||||
recalculatePromptTokens('txt2img_prompt')
|
||||
recalculatePromptTokens('txt2img_neg_prompt')
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function recalculate_prompts_img2img(){
|
||||
recalculatePromptTokens('img2img_prompt')
|
||||
recalculatePromptTokens('img2img_neg_prompt')
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
|
||||
opts = {}
|
||||
onUiUpdate(function(){
|
||||
if(Object.keys(opts).length != 0) return;
|
||||
|
@ -232,14 +254,12 @@ onUiUpdate(function(){
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
prompt.parentElement.insertBefore(counter, prompt)
|
||||
counter.classList.add("token-counter")
|
||||
prompt.parentElement.style.position = "relative"
|
||||
|
||||
textarea.addEventListener("input", function(){
|
||||
update_token_counter(id_button);
|
||||
});
|
||||
promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
|
||||
textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
|
||||
}
|
||||
|
||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
||||
|
@ -273,7 +293,7 @@ onOptionsChanged(function(){
|
|||
|
||||
let txt2img_textarea, img2img_textarea = undefined;
|
||||
let wait_time = 800
|
||||
let token_timeout;
|
||||
let token_timeouts = {};
|
||||
|
||||
function update_txt2img_tokens(...args) {
|
||||
update_token_counter("txt2img_token_button")
|
||||
|
@ -290,9 +310,9 @@ function update_img2img_tokens(...args) {
|
|||
}
|
||||
|
||||
function update_token_counter(button_id) {
|
||||
if (token_timeout)
|
||||
clearTimeout(token_timeout);
|
||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||
if (token_timeouts[button_id])
|
||||
clearTimeout(token_timeouts[button_id]);
|
||||
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||
}
|
||||
|
||||
function restart_reload(){
|
||||
|
@ -309,3 +329,10 @@ function updateInput(target){
|
|||
Object.defineProperty(e, "target", {value: target})
|
||||
target.dispatchEvent(e);
|
||||
}
|
||||
|
||||
|
||||
var desiredCheckpointName = null;
|
||||
function selectCheckpoint(name){
|
||||
desiredCheckpointName = name;
|
||||
gradioApp().getElementById('change_checkpoint').click()
|
||||
}
|
||||
|
|
45
launch.py
45
launch.py
|
@ -17,6 +17,37 @@ stored_commit_hash = None
|
|||
skip_install = False
|
||||
|
||||
|
||||
def check_python_version():
|
||||
is_windows = platform.system() == "Windows"
|
||||
major = sys.version_info.major
|
||||
minor = sys.version_info.minor
|
||||
micro = sys.version_info.micro
|
||||
|
||||
if is_windows:
|
||||
supported_minors = [10]
|
||||
else:
|
||||
supported_minors = [7, 8, 9, 10, 11]
|
||||
|
||||
if not (major == 3 and minor in supported_minors):
|
||||
import modules.errors
|
||||
|
||||
modules.errors.print_error_explanation(f"""
|
||||
INCOMPATIBLE PYTHON VERSION
|
||||
|
||||
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
||||
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
||||
or any other error regarding unsuccessful package (library) installation,
|
||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||
and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
|
||||
|
||||
def commit_hash():
|
||||
global stored_commit_hash
|
||||
|
||||
|
@ -188,12 +219,11 @@ def run_extensions_installers(settings_file):
|
|||
def prepare_environment():
|
||||
global skip_install
|
||||
|
||||
pip_installer_location = os.environ.get('PIP_INSTALLER_LOCATION', None)
|
||||
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
|
||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||
|
@ -218,6 +248,7 @@ def prepare_environment():
|
|||
|
||||
sys.argv, _ = extract_arg(sys.argv, '-f')
|
||||
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
|
||||
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
|
||||
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
||||
|
@ -226,14 +257,14 @@ def prepare_environment():
|
|||
xformers = '--xformers' in sys.argv
|
||||
ngrok = '--ngrok' in sys.argv
|
||||
|
||||
if not skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
commit = commit_hash()
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if pip_installer_location is not None and not is_installed("pip"):
|
||||
run(f'"{python}" "{pip_installer_location}"', "Installing pip", "Couldn't install pip")
|
||||
|
||||
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
|
||||
|
@ -252,14 +283,14 @@ def prepare_environment():
|
|||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||
if platform.system() == "Windows":
|
||||
if platform.python_version().startswith("3.10"):
|
||||
run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
else:
|
||||
print("Installation of xformers is not supported in this version of Python.")
|
||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||
if not is_installed("xformers"):
|
||||
exit(0)
|
||||
elif platform.system() == "Linux":
|
||||
run_pip("install xformers==0.0.16rc425", "xformers")
|
||||
run_pip(f"install {xformers_package}", "xformers")
|
||||
|
||||
if not is_installed("pyngrok") and ngrok:
|
||||
run_pip("install pyngrok", "ngrok")
|
||||
|
|
|
@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
|
@ -387,7 +388,7 @@ class Api:
|
|||
]
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
|
|
@ -228,7 +228,7 @@ class SDModelItem(BaseModel):
|
|||
hash: Optional[str] = Field(title="Short hash")
|
||||
sha256: Optional[str] = Field(title="sha256 hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
config: Optional[str] = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import modules.face_restoration
|
||||
import modules.shared
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import script_path, models_path
|
||||
from modules.paths import models_path
|
||||
|
||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
|
|
|
@ -2,6 +2,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules import devices
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
|
@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
|
|||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded)
|
||||
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
|
|
|
@ -16,6 +16,10 @@ def has_mps() -> bool:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def has_dml():
|
||||
import importlib
|
||||
loader = importlib.find_loader('torch_directml')
|
||||
return loader is not None
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
|
@ -34,14 +38,25 @@ def get_cuda_device_string():
|
|||
return "cuda"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device(get_cuda_device_string())
|
||||
def get_optimal_device_name():
|
||||
if has_dml():
|
||||
return "dml"
|
||||
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
return "mps"
|
||||
|
||||
return cpu
|
||||
if torch.cuda.is_available():
|
||||
return get_cuda_device_string()
|
||||
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if get_optimal_device_name() == "dml":
|
||||
import torch_directml
|
||||
return torch_directml.device()
|
||||
|
||||
return torch.device(get_optimal_device_name())
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
|
@ -79,6 +94,16 @@ cpu = torch.device("cpu")
|
|||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
dtype_unet = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
def cond_cast_unet(input):
|
||||
return input.to(dtype_unet) if unet_needs_upcast else input
|
||||
|
||||
|
||||
def cond_cast_float(input):
|
||||
return input.float() if unet_needs_upcast else input
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
|
@ -106,6 +131,10 @@ def autocast(disable=False):
|
|||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
def without_autocast(disable=False):
|
||||
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
||||
|
||||
|
||||
class NansException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -123,7 +152,7 @@ def test_for_nans(x, where):
|
|||
message = "A tensor with all NaNs was produced in Unet."
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
|
||||
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
|
||||
|
||||
elif where == "vae":
|
||||
message = "A tensor with all NaNs was produced in VAE."
|
||||
|
@ -133,6 +162,8 @@ def test_for_nans(x, where):
|
|||
else:
|
||||
message = "A tensor with all NaNs was produced."
|
||||
|
||||
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||
|
||||
raise NansException(message)
|
||||
|
||||
|
||||
|
@ -187,6 +218,22 @@ if has_mps():
|
|||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||
orig_narrow = torch.narrow
|
||||
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
|
||||
|
||||
if has_dml():
|
||||
_cumsum = torch.cumsum
|
||||
_repeat_interleave = torch.repeat_interleave
|
||||
_multinomial = torch.multinomial
|
||||
|
||||
_Tensor_new = torch.Tensor.new
|
||||
_Tensor_cumsum = torch.Tensor.cumsum
|
||||
_Tensor_repeat_interleave = torch.Tensor.repeat_interleave
|
||||
_Tensor_multinomial = torch.Tensor.multinomial
|
||||
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
|
||||
|
||||
torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
|
||||
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
|
|
@ -7,9 +7,11 @@ import git
|
|||
from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
extensions_dir = os.path.join(paths.data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
if not os.path.exists(extensions_dir):
|
||||
os.makedirs(extensions_dir)
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from modules import extra_networks
|
||||
from modules import extra_networks, shared, extra_networks
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
|
@ -7,6 +7,12 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
|||
super().__init__('hypernet')
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_hypernetwork
|
||||
|
||||
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
multipliers = []
|
||||
for params in params_list:
|
||||
|
|
|
@ -6,7 +6,7 @@ import shutil
|
|||
import torch
|
||||
import tqdm
|
||||
|
||||
from modules import shared, images, sd_models, sd_vae
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
||||
from modules.ui_common import plaintext_to_html
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
|
@ -37,7 +37,7 @@ def run_pnginfo(image):
|
|||
|
||||
def create_config(ckpt_result, config_source, a, b, c):
|
||||
def config(x):
|
||||
res = sd_models.find_checkpoint_config(x) if x else None
|
||||
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
||||
return res if res != shared.sd_default_config else None
|
||||
|
||||
if config_source == 0:
|
||||
|
@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
||||
|
||||
result_is_inpainting_model = False
|
||||
result_is_instruct_pix2pix_model = False
|
||||
|
||||
if theta_func2:
|
||||
shared.state.textinfo = f"Loading B"
|
||||
|
@ -185,14 +186,19 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
||||
if a.shape[1] == 4 and b.shape[1] == 9:
|
||||
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
||||
if a.shape[1] == 4 and b.shape[1] == 8:
|
||||
raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
|
||||
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
|
||||
result_is_instruct_pix2pix_model = True
|
||||
else:
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
else:
|
||||
theta_0[key] = theta_func2(a, b, multiplier)
|
||||
|
||||
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
|
@ -226,6 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||
|
||||
filename = filename_generator() if custom_name == '' else custom_name
|
||||
filename += ".inpainting" if result_is_inpainting_model else ""
|
||||
filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
|
||||
filename += "." + checkpoint_format
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import base64
|
||||
import html
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
|
@ -6,24 +7,33 @@ import re
|
|||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
||||
type_of_gr_update = type(gr.update())
|
||||
|
||||
paste_fields = {}
|
||||
bind_list = []
|
||||
registered_param_bindings = []
|
||||
|
||||
|
||||
class ParamBinding:
|
||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
|
||||
self.paste_button = paste_button
|
||||
self.tabname = tabname
|
||||
self.source_text_component = source_text_component
|
||||
self.source_image_component = source_image_component
|
||||
self.source_tabname = source_tabname
|
||||
self.override_settings_component = override_settings_component
|
||||
|
||||
|
||||
def reset():
|
||||
paste_fields.clear()
|
||||
bind_list.clear()
|
||||
|
||||
|
||||
def quote(text):
|
||||
|
@ -75,26 +85,6 @@ def add_paste_fields(tabname, init_img, fields):
|
|||
modules.ui.img2img_paste_fields = fields
|
||||
|
||||
|
||||
def integrate_settings_paste_fields(component_dict):
|
||||
from modules import ui
|
||||
|
||||
settings_map = {
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'inpainting_mask_weight': 'Conditional mask weight',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
'eta_noise_seed_delta': 'ENSD',
|
||||
'initial_noise_multiplier': 'Noise multiplier',
|
||||
}
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
for tabname, info in paste_fields.items():
|
||||
if info["fields"] is not None:
|
||||
info["fields"] += settings_paste_fields
|
||||
|
||||
|
||||
def create_buttons(tabs_list):
|
||||
buttons = {}
|
||||
for tab in tabs_list:
|
||||
|
@ -102,9 +92,60 @@ def create_buttons(tabs_list):
|
|||
return buttons
|
||||
|
||||
|
||||
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
|
||||
def bind_buttons(buttons, send_image, send_generate_info):
|
||||
bind_list.append([buttons, send_image, send_generate_info])
|
||||
"""old function for backwards compatibility; do not use this, use register_paste_params_button"""
|
||||
for tabname, button in buttons.items():
|
||||
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
||||
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
||||
|
||||
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
|
||||
|
||||
|
||||
def register_paste_params_button(binding: ParamBinding):
|
||||
registered_param_bindings.append(binding)
|
||||
|
||||
|
||||
def connect_paste_params_buttons():
|
||||
binding: ParamBinding
|
||||
for binding in registered_param_bindings:
|
||||
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
||||
fields = paste_fields[binding.tabname]["fields"]
|
||||
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
||||
if binding.source_image_component and destination_image_component:
|
||||
if isinstance(binding.source_image_component, gr.Gallery):
|
||||
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
||||
jsfunc = "extract_image_from_gallery"
|
||||
else:
|
||||
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
||||
jsfunc = None
|
||||
|
||||
binding.paste_button.click(
|
||||
fn=func,
|
||||
_js=jsfunc,
|
||||
inputs=[binding.source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||
)
|
||||
|
||||
if binding.source_text_component is not None and fields is not None:
|
||||
connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
|
||||
|
||||
if binding.source_tabname is not None and fields is not None:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
binding.paste_button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
)
|
||||
|
||||
binding.paste_button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{binding.tabname}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
|
||||
def send_image_and_dimensions(x):
|
||||
|
@ -123,49 +164,6 @@ def send_image_and_dimensions(x):
|
|||
return img, w, h
|
||||
|
||||
|
||||
def run_bind():
|
||||
for buttons, source_image_component, send_generate_info in bind_list:
|
||||
for tab in buttons:
|
||||
button = buttons[tab]
|
||||
destination_image_component = paste_fields[tab]["init_img"]
|
||||
fields = paste_fields[tab]["fields"]
|
||||
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
||||
if source_image_component and destination_image_component:
|
||||
if isinstance(source_image_component, gr.Gallery):
|
||||
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
||||
jsfunc = "extract_image_from_gallery"
|
||||
else:
|
||||
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
||||
jsfunc = None
|
||||
|
||||
button.click(
|
||||
fn=func,
|
||||
_js=jsfunc,
|
||||
inputs=[source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||
)
|
||||
|
||||
if send_generate_info and fields is not None:
|
||||
if send_generate_info in paste_fields:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
)
|
||||
else:
|
||||
connect_paste(button, fields, send_generate_info)
|
||||
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{tab}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
|
||||
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
||||
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
||||
|
@ -243,7 +241,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
done_with_prompt = False
|
||||
|
||||
*lines, lastline = x.strip().split("\n")
|
||||
if not re_params.match(lastline):
|
||||
if len(re_param.findall(lastline)) < 3:
|
||||
lines.append(lastline)
|
||||
lastline = ''
|
||||
|
||||
|
@ -262,6 +260,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
res["Negative prompt"] = negative_prompt
|
||||
|
||||
for k, v in re_param.findall(lastline):
|
||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||
m = re_imagesize.match(v)
|
||||
if m is not None:
|
||||
res[k+"-1"] = m.group(1)
|
||||
|
@ -286,10 +285,53 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||
settings_map = {}
|
||||
|
||||
infotext_to_setting_name_mapping = [
|
||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||
('Model hash', 'sd_model_checkpoint'),
|
||||
('ENSD', 'eta_noise_seed_delta'),
|
||||
('Noise multiplier', 'initial_noise_multiplier'),
|
||||
('Eta', 'eta_ancestral'),
|
||||
('Eta DDIM', 'eta_ddim'),
|
||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
|
||||
]
|
||||
|
||||
|
||||
def create_override_settings_dict(text_pairs):
|
||||
"""creates processing's override_settings parameters from gradio's multiselect
|
||||
|
||||
Example input:
|
||||
['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
|
||||
|
||||
Example output:
|
||||
{'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
|
||||
"""
|
||||
|
||||
res = {}
|
||||
|
||||
params = {}
|
||||
for pair in text_pairs:
|
||||
k, v = pair.split(":", maxsplit=1)
|
||||
|
||||
params[k] = v.strip()
|
||||
|
||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
||||
value = params.get(param_name, None)
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
res[setting_name] = shared.opts.cast_value(setting_name, value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(script_path, "params.txt")
|
||||
filename = os.path.join(data_path, "params.txt")
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
prompt = file.read()
|
||||
|
@ -323,9 +365,35 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
|||
|
||||
return res
|
||||
|
||||
if override_settings_component is not None:
|
||||
def paste_settings(params):
|
||||
vals = {}
|
||||
|
||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
||||
continue
|
||||
|
||||
v = shared.opts.cast_value(setting_name, v)
|
||||
current_value = getattr(shared.opts, setting_name, None)
|
||||
|
||||
if v == current_value:
|
||||
continue
|
||||
|
||||
vals[param_name] = v
|
||||
|
||||
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
|
||||
|
||||
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
|
||||
|
||||
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
|
||||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=jsfunc,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
|
|
|
@ -6,12 +6,11 @@ import facexlib
|
|||
import gfpgan
|
||||
|
||||
import modules.face_restoration
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import models_path
|
||||
from modules import paths, shared, devices, modelloader
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
user_path = None
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_path = os.path.join(paths.models_path, model_dir)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
|
|
|
@ -4,8 +4,11 @@ import os.path
|
|||
|
||||
import filelock
|
||||
|
||||
from modules import shared
|
||||
from modules.paths import data_path
|
||||
|
||||
cache_filename = "cache.json"
|
||||
|
||||
cache_filename = os.path.join(data_path, "cache.json")
|
||||
cache_data = None
|
||||
|
||||
|
||||
|
@ -66,6 +69,9 @@ def sha256(filename, title):
|
|||
if sha256_value is not None:
|
||||
return sha256_value
|
||||
|
||||
if shared.cmd_opts.no_hashing:
|
||||
return None
|
||||
|
||||
print(f"Calculating sha256 for {filename}: ", end='')
|
||||
sha256_value = calculate_sha256(filename)
|
||||
print(f"{sha256_value}")
|
||||
|
|
|
@ -307,7 +307,7 @@ class Hypernetwork:
|
|||
def shorthash(self):
|
||||
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||
|
||||
return sha256[0:10]
|
||||
return sha256[0:10] if sha256 else None
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
|
|
|
@ -16,6 +16,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
|||
from fonts.ttf import Roboto
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from modules import sd_samplers, shared, script_callbacks
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
@ -36,6 +37,8 @@ def image_grid(imgs, batch_size=1, rows=None):
|
|||
else:
|
||||
rows = math.sqrt(len(imgs))
|
||||
rows = round(rows)
|
||||
if rows > len(imgs):
|
||||
rows = len(imgs)
|
||||
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
|
||||
|
@ -128,7 +131,7 @@ class GridAnnotation:
|
|||
self.size = None
|
||||
|
||||
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
def wrap(drawing, text, font, line_length):
|
||||
lines = ['']
|
||||
for word in text.split():
|
||||
|
@ -192,32 +195,35 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
|||
line.allowed_width = allowed_width
|
||||
|
||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||
ver_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
||||
|
||||
pad_top = max(hor_text_heights) + line_spacing * 2
|
||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||
|
||||
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
|
||||
result.paste(im, (pad_left, pad_top))
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
|
||||
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
||||
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
||||
|
||||
d = ImageDraw.Draw(result)
|
||||
|
||||
for col in range(cols):
|
||||
x = pad_left + width * col + width / 2
|
||||
x = pad_left + (width + margin) * col + width / 2
|
||||
y = pad_top / 2 - hor_text_heights[col] / 2
|
||||
|
||||
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
||||
|
||||
for row in range(rows):
|
||||
x = pad_left / 2
|
||||
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
|
||||
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
||||
|
||||
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
||||
prompts = all_prompts[1:]
|
||||
boundary = math.ceil(len(prompts) / 2)
|
||||
|
||||
|
@ -227,7 +233,7 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
|||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
||||
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
||||
|
||||
|
||||
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
|
@ -338,6 +344,7 @@ class FilenameGenerator:
|
|||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
|
|
|
@ -7,6 +7,7 @@ import numpy as np
|
|||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||
|
||||
from modules import devices, sd_samplers
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
|
@ -16,11 +17,18 @@ import modules.images as images
|
|||
import modules.scripts
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, args):
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = shared.listfiles(input_dir)
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
inpaint_masks = shared.listfiles(inpaint_mask_dir)
|
||||
is_inpaint_batch = len(inpaint_masks) > 0
|
||||
if is_inpaint_batch:
|
||||
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
save_normally = output_dir == ''
|
||||
|
@ -43,6 +51,15 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
if is_inpaint_batch:
|
||||
# try to find corresponding mask for an image using simple filename matching
|
||||
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
||||
# if not found use first one ("same mask for all images" use-case)
|
||||
if not mask_image_path in inpaint_masks:
|
||||
mask_image_path = inpaint_masks[0]
|
||||
mask_image = Image.open(mask_image_path)
|
||||
p.image_mask = mask_image
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if proc is None:
|
||||
proc = process_images(p)
|
||||
|
@ -59,7 +76,9 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
|
||||
if mode == 0: # img2img
|
||||
|
@ -123,9 +142,11 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||
inpainting_fill=inpainting_fill,
|
||||
resize_mode=resize_mode,
|
||||
denoising_strength=denoising_strength,
|
||||
image_cfg_scale=image_cfg_scale,
|
||||
inpaint_full_res=inpaint_full_res,
|
||||
inpaint_full_res_padding=inpaint_full_res_padding,
|
||||
inpainting_mask_invert=inpainting_mask_invert,
|
||||
override_settings=override_settings,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
|
@ -139,7 +160,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
|
||||
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
else:
|
||||
|
|
|
@ -12,7 +12,7 @@ from torchvision import transforms
|
|||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import devices, paths, lowvram, modelloader, errors
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
|
53
modules/mac_specific.py
Normal file
53
modules/mac_specific.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import torch
|
||||
from modules import paths
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def check_for_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
has_mps = check_for_mps()
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if output_dtype == torch.int64:
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
if has_mps:
|
||||
# MPS fix for randn in torchsde
|
||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
|
|
@ -45,6 +45,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||
full_path = file
|
||||
if os.path.isdir(full_path):
|
||||
continue
|
||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||
print(f"Skipping broken symlink: {full_path}")
|
||||
continue
|
||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
|
|
1459
modules/models/diffusion/ddpm_edit.py
Normal file
1459
modules/models/diffusion/ddpm_edit.py
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -4,7 +4,15 @@ import sys
|
|||
import modules.safe
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
models_path = os.path.join(script_path, "models")
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser.parse_known_args()[0]
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
models_path = os.path.join(data_path, "models")
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffusion in following places
|
||||
|
|
|
@ -13,10 +13,11 @@ from skimage import exposure
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.paths as paths
|
||||
import modules.face_restoration
|
||||
import modules.images as images
|
||||
import modules.styles
|
||||
|
@ -184,7 +185,12 @@ class StableDiffusionProcessing:
|
|||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||
return conditioning
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
|
||||
|
||||
return conditioning_image
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
|
@ -203,7 +209,7 @@ class StableDiffusionProcessing:
|
|||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
||||
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
|
||||
conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
|
||||
conditioning_image = torch.lerp(
|
||||
source_image,
|
||||
source_image * (1.0 - conditioning_mask),
|
||||
|
@ -222,11 +228,16 @@ class StableDiffusionProcessing:
|
|||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
source_image = devices.cond_cast_float(source_image)
|
||||
|
||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||
return self.depth2img_image_conditioning(source_image)
|
||||
|
||||
if self.sd_model.cond_stage_key == "edit":
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
|
@ -257,6 +268,7 @@ class Processed:
|
|||
self.height = p.height
|
||||
self.sampler_name = p.sampler_name
|
||||
self.cfg_scale = p.cfg_scale
|
||||
self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.steps = p.steps
|
||||
self.batch_size = p.batch_size
|
||||
self.restore_faces = p.restore_faces
|
||||
|
@ -434,19 +446,17 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Steps": p.steps,
|
||||
"Sampler": p.sampler_name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
}
|
||||
|
@ -568,10 +578,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
|
||||
# for OSX, loading the model during sampling changes the generated picture, so it is loaded here
|
||||
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
||||
sd_vae_approx.model()
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
extra_networks.activate(p, extra_network_data)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
|
@ -610,7 +624,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
with devices.autocast():
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||
|
||||
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
||||
|
@ -645,6 +659,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
image = Image.fromarray(x_sample)
|
||||
|
||||
if p.scripts is not None:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
image = pp.image
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
|
@ -884,12 +903,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.init_images = init_images
|
||||
self.resize_mode: int = resize_mode
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||
self.init_latent = None
|
||||
self.image_mask = mask
|
||||
self.latent_mask = None
|
||||
|
|
|
@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||
scale=info.scale,
|
||||
model_path=info.local_data_path,
|
||||
model=info.model(),
|
||||
half=not cmd_opts.no_half,
|
||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(os.path.basename(path))
|
||||
exec(compiled, module.__dict__)
|
||||
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
|
||||
return module
|
||||
|
||||
|
|
|
@ -6,12 +6,16 @@ from collections import namedtuple
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
|
||||
class Script:
|
||||
filename = None
|
||||
args_from = None
|
||||
|
@ -65,7 +69,7 @@ class Script:
|
|||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
|
@ -100,6 +104,13 @@ class Script:
|
|||
|
||||
pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
|
@ -247,11 +258,15 @@ class ScriptRunner:
|
|||
self.infotext_fields = []
|
||||
|
||||
def initialize_scripts(self, is_img2img):
|
||||
from modules import scripts_auto_postprocessing
|
||||
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
|
||||
for script_class, path, basedir, script_module in scripts_data:
|
||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
|
@ -330,9 +345,23 @@ class ScriptRunner:
|
|||
outputs=[script.group for script in self.selectable_scripts]
|
||||
)
|
||||
|
||||
self.script_load_ctr = 0
|
||||
def onload_script_visibility(params):
|
||||
title = params.get('Script', None)
|
||||
if title:
|
||||
title_index = self.titles.index(title)
|
||||
visibility = title_index == self.script_load_ctr
|
||||
self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
|
||||
return gr.update(visible=visibility)
|
||||
else:
|
||||
return gr.update(visible=False)
|
||||
|
||||
self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
|
||||
self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
|
||||
|
||||
return inputs
|
||||
|
||||
def run(self, p: StableDiffusionProcessing, *args):
|
||||
def run(self, p, *args):
|
||||
script_index = args[0]
|
||||
|
||||
if script_index == 0:
|
||||
|
@ -386,6 +415,15 @@ class ScriptRunner:
|
|||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_image(p, pp, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
|
|
42
modules/scripts_auto_postprocessing.py
Normal file
42
modules/scripts_auto_postprocessing.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
from modules import scripts, scripts_postprocessing, shared
|
||||
|
||||
|
||||
class ScriptPostprocessingForMainUI(scripts.Script):
|
||||
def __init__(self, script_postproc):
|
||||
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
|
||||
self.postprocessing_controls = None
|
||||
|
||||
def title(self):
|
||||
return self.script.name
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.postprocessing_controls = self.script.ui()
|
||||
return self.postprocessing_controls.values()
|
||||
|
||||
def postprocess_image(self, p, script_pp, *args):
|
||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||
pp.info = {}
|
||||
self.script.process(pp, **args_dict)
|
||||
p.extra_generation_params.update(pp.info)
|
||||
script_pp.image = pp.image
|
||||
|
||||
|
||||
def create_auto_preprocessing_script_data():
|
||||
from modules import scripts
|
||||
|
||||
res = []
|
||||
|
||||
for name in shared.opts.postprocessing_enable_in_main_ui:
|
||||
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
|
||||
if script is None:
|
||||
continue
|
||||
|
||||
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
|
||||
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
|
||||
|
||||
return res
|
|
@ -46,6 +46,8 @@ class ScriptPostprocessing:
|
|||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
|
@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
|
|||
script: ScriptPostprocessing = script_class()
|
||||
script.filename = path
|
||||
|
||||
if script.name == "Simple Upscale":
|
||||
continue
|
||||
|
||||
self.scripts.append(script)
|
||||
|
||||
def create_script_ui(self, script, inputs):
|
||||
|
@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
|
|||
import modules.scripts
|
||||
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
||||
|
||||
scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
|
||||
scripts_order = shared.opts.postprocessing_operation_order
|
||||
|
||||
def script_score(name):
|
||||
name = name.lower()
|
||||
for i, possible_match in enumerate(scripts_order):
|
||||
if possible_match in name:
|
||||
if possible_match == name:
|
||||
return i
|
||||
|
||||
return len(self.scripts)
|
||||
|
@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
|
|||
def image_changed(self):
|
||||
for script in self.scripts_in_preferred_order():
|
||||
script.image_changed()
|
||||
|
||||
|
|
|
@ -20,8 +20,9 @@ class DisableInitialization:
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, disable_clip=True):
|
||||
self.replaced = []
|
||||
self.disable_clip = disable_clip
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
|
@ -75,12 +76,14 @@ class DisableInitialization:
|
|||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
if self.disable_clip:
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for obj, field, original in self.replaced:
|
||||
|
|
|
@ -131,6 +131,8 @@ class StableDiffusionModelHijack:
|
|||
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
undo_optimizations()
|
||||
|
||||
self.apply_circular(False)
|
||||
self.layers = None
|
||||
self.clip = None
|
||||
|
@ -171,7 +173,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = embedding.vec
|
||||
emb = devices.cond_cast_unet(embedding.vec)
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
|
|
|
@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
from modules import sd_models
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||
|
||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
|
|
13
modules/sd_hijack_ip2p.py
Normal file
13
modules/sd_hijack_ip2p.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import time
|
||||
|
||||
def should_hijack_ip2p(checkpoint_info):
|
||||
from modules import sd_models_config
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||
|
||||
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
|
|
@ -9,7 +9,7 @@ from torch import einsum
|
|||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
from modules import shared, errors
|
||||
from modules import shared, errors, devices
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
del q, k, v
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
del q, k, v
|
||||
r1 = r1.to(dtype)
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
k_in *= self.scale
|
||||
dtype = q_in.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
|
||||
|
||||
del context, x
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k_in = k_in * self.scale
|
||||
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
r1 = r1.to(dtype)
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k = self.to_k(context_k) * self.scale
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k = k * self.scale
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
r = r.to(dtype)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
||||
|
@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||
|
||||
x = x.to(dtype)
|
||||
|
||||
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
|
||||
|
||||
out_proj, dropout = self.to_out
|
||||
|
@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
|
||||
return efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_chunk_size=q_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min = kv_chunk_size_min,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||
return efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_chunk_size=q_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min = kv_chunk_size_min,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def get_xformers_flash_attention_op(q, k, v):
|
||||
|
@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
||||
|
||||
out = out.to(dtype)
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
|
|||
v = self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
|
||||
out = out.to(dtype)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
import torch
|
||||
from packaging import version
|
||||
|
||||
from modules import devices
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
|
||||
|
||||
class TorchHijackForUnet:
|
||||
|
@ -28,3 +32,37 @@ class TorchHijackForUnet:
|
|||
|
||||
|
||||
th = TorchHijackForUnet()
|
||||
|
||||
|
||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
for y in cond.keys():
|
||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
|
||||
with devices.autocast():
|
||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||
|
||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
def forward(self, x):
|
||||
if devices.unet_needs_upcast:
|
||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
else:
|
||||
return torch.nn.GELU.forward(self, x)
|
||||
|
||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
|
|
28
modules/sd_hijack_utils.py
Normal file
28
modules/sd_hijack_utils.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import importlib
|
||||
|
||||
class CondFunc:
|
||||
def __new__(cls, orig_func, sub_func, cond_func):
|
||||
self = super(CondFunc, cls).__new__(cls)
|
||||
if isinstance(orig_func, str):
|
||||
func_path = orig_func.split('.')
|
||||
for i in range(len(func_path)-1, -1, -1):
|
||||
try:
|
||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
||||
break
|
||||
except ImportError:
|
||||
pass
|
||||
for attr_name in func_path[i:-1]:
|
||||
resolved_obj = getattr(resolved_obj, attr_name)
|
||||
orig_func = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
||||
self.__init__(orig_func, sub_func, cond_func)
|
||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
||||
def __init__(self, orig_func, sub_func, cond_func):
|
||||
self.__orig_func = orig_func
|
||||
self.__sub_func = sub_func
|
||||
self.__cond_func = cond_func
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
||||
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
||||
else:
|
||||
return self.__orig_func(*args, **kwargs)
|
|
@ -2,8 +2,6 @@ import collections
|
|||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import time
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
|
@ -14,12 +12,13 @@ import ldm.modules.midas as midas
|
|||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
|
||||
checkpoints_list = {}
|
||||
checkpoint_alisases = {}
|
||||
|
@ -42,6 +41,7 @@ class CheckpointInfo:
|
|||
name = name[1:]
|
||||
|
||||
self.name = name
|
||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
self.hash = model_hash(filename)
|
||||
|
||||
|
@ -59,13 +59,17 @@ class CheckpointInfo:
|
|||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||
if self.sha256 is None:
|
||||
return
|
||||
|
||||
self.shorthash = self.sha256[0:10]
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256]
|
||||
self.register()
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
||||
|
||||
checkpoints_list.pop(self.title)
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
self.register()
|
||||
|
||||
return self.shorthash
|
||||
|
||||
|
@ -98,17 +102,6 @@ def checkpoint_tiles():
|
|||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||
|
||||
|
||||
def find_checkpoint_config(info):
|
||||
if info is None:
|
||||
return shared.cmd_opts.config
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return shared.cmd_opts.config
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
checkpoint_alisases.clear()
|
||||
|
@ -169,7 +162,7 @@ def select_checkpoint():
|
|||
print(f" - directory {model_path}", file=sys.stderr)
|
||||
if shared.cmd_opts.ckpt_dir is not None:
|
||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||
print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||
|
@ -214,9 +207,7 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location
|
||||
if device is None:
|
||||
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device()
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
@ -228,52 +219,72 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||
return sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
||||
title = checkpoint_info.title
|
||||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
if checkpoint_info.title != title:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
timer.record("calculate hash")
|
||||
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
|
||||
if cache_enabled and checkpoint_info in checkpoints_loaded:
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
else:
|
||||
# load from file
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
|
||||
sd = read_state_dict(checkpoint_info.filename)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
timer.record("load weights from disk")
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
return res
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
timer.record("apply channels_last")
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
# with --upcast-sampling, don't convert the depth model weights to float16
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
|
||||
timer.record("apply half()")
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
timer.record("apply dtype to VAE")
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_info.filename
|
||||
|
@ -286,6 +297,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
|||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||
sd_vae.load_vae(model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
|
@ -298,7 +310,7 @@ def enable_midas_autodownload():
|
|||
location automatically.
|
||||
"""
|
||||
|
||||
midas_path = os.path.join(models_path, 'midas')
|
||||
midas_path = os.path.join(paths.models_path, 'midas')
|
||||
|
||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||
# a location that differs from where other scripts using this model look.
|
||||
|
@ -331,24 +343,23 @@ def enable_midas_autodownload():
|
|||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.start = time.time()
|
||||
def repair_config(sd_config):
|
||||
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
res = end - self.start
|
||||
self.start = end
|
||||
return res
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||
|
||||
if checkpoint_config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
|
@ -356,29 +367,30 @@ def load_model(checkpoint_info=None):
|
|||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
# Hardcoded config for now...
|
||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
sd_config.model.params.conditioning_key = "hybrid"
|
||||
sd_config.model.params.unet_config.params.in_channels = 9
|
||||
sd_config.model.params.finetune_keys = None
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
|
||||
timer = Timer()
|
||||
|
||||
sd_model = None
|
||||
if already_loaded_state_dict is not None:
|
||||
state_dict = already_loaded_state_dict
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
repair_config(sd_config)
|
||||
|
||||
timer.record("load config")
|
||||
|
||||
print(f"Creating model from config: {checkpoint_config}")
|
||||
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
@ -387,29 +399,35 @@ def load_model(checkpoint_info=None):
|
|||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
elapsed_create = timer.elapsed()
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
timer.record("create model")
|
||||
|
||||
elapsed_load_weights = timer.elapsed()
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||
else:
|
||||
sd_model.to(shared.device)
|
||||
|
||||
timer.record("move model to device")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
timer.record("load textual inversion embeddings")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
elapsed_the_rest = timer.elapsed()
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
|
||||
print(f"Model loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
|
@ -420,6 +438,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
if sd_model is None: # previous model load failed
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
|
@ -427,38 +446,44 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info)
|
||||
return shared.sd_model
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
timer = Timer()
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||
return shared.sd_model
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
except Exception as e:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info)
|
||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
timer.record("hijack")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
elapsed = timer.elapsed()
|
||||
|
||||
print(f"Weights loaded in {elapsed:.1f}s.")
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
|
112
modules/sd_models_config.py
Normal file
112
modules/sd_models_config.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import re
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared, paths, sd_disable_initialization
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
|
||||
"""
|
||||
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||
"""
|
||||
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
from modules import devices
|
||||
|
||||
device = devices.cpu
|
||||
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||
use_checkpoint=True,
|
||||
use_fp16=False,
|
||||
image_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
model_channels=320,
|
||||
attention_resolutions=[4, 2, 1],
|
||||
num_res_blocks=2,
|
||||
channel_mult=[1, 2, 4, 4],
|
||||
num_head_channels=64,
|
||||
use_spatial_transformer=True,
|
||||
use_linear_in_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=1024,
|
||||
legacy=False
|
||||
)
|
||||
unet.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||
unet.load_state_dict(unet_sd, strict=True)
|
||||
unet.to(device=device, dtype=torch.float)
|
||||
|
||||
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||
|
||||
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
|
||||
|
||||
return out < -1
|
||||
|
||||
|
||||
def guess_model_config_from_state_dict(sd, filename):
|
||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
|
||||
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
|
||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sd2_inpainting
|
||||
elif is_using_v_parameterization_for_sd2(sd):
|
||||
return config_sd2v
|
||||
else:
|
||||
return config_sd2
|
||||
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_inpainting
|
||||
if diffusion_model_input.shape[1] == 8:
|
||||
return config_instruct_pix2pix
|
||||
|
||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
return config_alt_diffusion
|
||||
|
||||
return config_default
|
||||
|
||||
|
||||
def find_checkpoint_config(state_dict, info):
|
||||
if info is None:
|
||||
return guess_model_config_from_state_dict(state_dict, "")
|
||||
|
||||
config = find_checkpoint_config_near_filename(info)
|
||||
if config is not None:
|
||||
return config
|
||||
|
||||
return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
|
||||
|
||||
def find_checkpoint_config_near_filename(info):
|
||||
if info is None:
|
||||
return None
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return None
|
||||
|
|
@ -1,53 +1,11 @@
|
|||
from collections import namedtuple, deque
|
||||
import numpy as np
|
||||
from math import floor
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
import torchsde._brownian.brownian_interval
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser, devices, processing, images, sd_vae_approx
|
||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
||||
|
||||
all_samplers = [
|
||||
*samplers_data_k_diffusion,
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_compvis.samplers_data_compvis,
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
|
@ -73,8 +31,8 @@ def create_sampler(name, model):
|
|||
def set_samplers():
|
||||
global samplers, samplers_for_img2img
|
||||
|
||||
hidden = set(opts.hide_samplers)
|
||||
hidden_img2img = set(opts.hide_samplers + ['PLMS'])
|
||||
hidden = set(shared.opts.hide_samplers)
|
||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
|
||||
|
||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
|
@ -87,466 +45,3 @@ def set_samplers():
|
|||
|
||||
|
||||
set_samplers()
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
}
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
if opts.img2img_fix_steps or steps is not None:
|
||||
requested_steps = (steps or p.steps)
|
||||
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
t_enc = requested_steps - 1
|
||||
else:
|
||||
steps = p.steps
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
|
||||
def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 0.0
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
cond = tensor
|
||||
|
||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||
# not 100% correct but should work well enough
|
||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||
last_vector = unconditional_conditioning[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||
else:
|
||||
self.last_latent = res[1]
|
||||
|
||||
store_latent(self.last_latent)
|
||||
|
||||
self.step += 1
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ddim
|
||||
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
store_latent(x_out[0:uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
# MPS fix for randn in torchsde
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
else:
|
||||
generator = torch.Generator(device).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
||||
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 1.0
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
if opts.live_preview_content == "Combined":
|
||||
store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap.step = 0
|
||||
self.eta = p.eta or opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
if discard_next_to_last_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
|
||||
return sigmas
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
|
||||
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
||||
if 'sigma_max' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
||||
if 'sigma_sched' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_sched'] = sigma_sched
|
||||
if 'sigmas' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
x = x * sigmas[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
|
|
62
modules/sd_samplers_common.py
Normal file
62
modules/sd_samplers_common.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, processing, images, sd_vae_approx
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
if opts.img2img_fix_steps or steps is not None:
|
||||
requested_steps = (steps or p.steps)
|
||||
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
t_enc = requested_steps - 1
|
||||
else:
|
||||
steps = p.steps
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
|
||||
def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
160
modules/sd_samplers_compvis.py
Normal file
160
modules/sd_samplers_compvis.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
import math
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules.shared import state
|
||||
from modules import sd_samplers_common, prompt_parser, shared
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
]
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
cond = tensor
|
||||
|
||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||
# not 100% correct but should work well enough
|
||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||
last_vector = unconditional_conditioning[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||
else:
|
||||
self.last_latent = res[1]
|
||||
|
||||
sd_samplers_common.store_latent(self.last_latent)
|
||||
|
||||
self.step += 1
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
if self.eta != 0.0:
|
||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
||||
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == math.floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
331
modules/sd_samplers_kdiffusion.py
Normal file
331
modules/sd_samplers_kdiffusion.py
Normal file
|
@ -0,0 +1,331 @@
|
|||
from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import einops
|
||||
import k_diffusion.sampling
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
}
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
if not is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
else:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = [tensor[a:b]]
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if not is_edit_model:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
else:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
if opts.live_preview_content == "Combined":
|
||||
sd_samplers_common.store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params["Eta"] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
if discard_next_to_last_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
|
||||
return sigmas
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
|
||||
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
||||
if 'sigma_max' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
||||
if 'sigma_sched' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_sched'] = sigma_sched
|
||||
if 'sigmas' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
x = x * sigmas[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
|
@ -3,13 +3,12 @@ import safetensors.torch
|
|||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import shared, devices, script_callbacks, sd_models
|
||||
from modules.paths import models_path
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
|
||||
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
vae_dict = {}
|
||||
|
||||
|
|
|
@ -13,17 +13,19 @@ import modules.interrogate
|
|||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||
from modules.paths import models_path, script_path, data_path
|
||||
|
||||
|
||||
demo = None
|
||||
|
||||
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
|
@ -34,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo
|
|||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
|
@ -45,6 +47,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
|
|||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
|
@ -72,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp
|
|||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
|
@ -102,6 +105,8 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
|
|||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
|
@ -124,12 +129,13 @@ restricted_opts = {
|
|||
ui_reorder_categories = [
|
||||
"inpaint",
|
||||
"sampler",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"dimensions",
|
||||
"cfg",
|
||||
"seed",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"batch",
|
||||
"override_settings",
|
||||
"scripts",
|
||||
]
|
||||
|
||||
|
@ -263,12 +269,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||
|
||||
face_restorers = []
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
||||
self.default = default
|
||||
|
@ -327,7 +327,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||
|
||||
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||
|
@ -349,17 +349,17 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
|
||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||
"directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
|
||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
|
@ -396,7 +396,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
|
@ -408,7 +408,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
|
@ -433,7 +433,9 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }),
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
|
@ -441,7 +443,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
|
@ -481,7 +483,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
|
||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
}))
|
||||
|
||||
|
@ -605,11 +608,37 @@ class Options:
|
|||
|
||||
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
||||
|
||||
def cast_value(self, key, value):
|
||||
"""casts an arbitrary to the same type as this setting's value with key
|
||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
default_value = self.data_labels[key].default
|
||||
if default_value is None:
|
||||
default_value = getattr(self, key, None)
|
||||
if default_value is None:
|
||||
return None
|
||||
|
||||
expected_type = type(default_value)
|
||||
if expected_type == bool and value == "False":
|
||||
value = False
|
||||
else:
|
||||
value = expected_type(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
|
||||
opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
opts.load(config_filename)
|
||||
|
||||
settings_components = None
|
||||
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
|
||||
|
||||
latent_upscale_default_mode = "Latent"
|
||||
latent_upscale_modes = {
|
||||
"Latent": {"mode": "bilinear", "antialias": False},
|
||||
|
|
23
modules/shared_items.py
Normal file
23
modules/shared_items.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
def postprocessing_scripts():
|
||||
import modules.scripts
|
||||
|
||||
return modules.scripts.scripts_postproc.scripts
|
||||
|
||||
|
||||
def sd_vae_items():
|
||||
import modules.sd_vae
|
||||
|
||||
return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
|
||||
|
||||
|
||||
def refresh_vae_list():
|
||||
import modules.sd_vae
|
||||
|
||||
return modules.sd_vae.refresh_vae_list
|
|
@ -67,7 +67,7 @@ def _summarize_chunk(
|
|||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
exp_weights = torch.exp(attn_weights - max_score)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
|
||||
|
@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||
)
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
|
||||
return hidden_states_slice
|
||||
|
||||
|
||||
|
|
|
@ -6,8 +6,7 @@ import sys
|
|||
import tqdm
|
||||
import time
|
||||
|
||||
from modules import shared, images, deepbooru
|
||||
from modules.paths import models_path
|
||||
from modules import paths, shared, images, deepbooru
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.textual_inversion import autocrop
|
||||
|
||||
|
@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
|
||||
dnn_model_path = None
|
||||
try:
|
||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
|
||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
|
||||
except Exception as e:
|
||||
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
||||
|
||||
|
|
|
@ -112,6 +112,7 @@ class EmbeddingDatabase:
|
|||
self.skipped_embeddings = {}
|
||||
self.expected_shape = -1
|
||||
self.embedding_dirs = {}
|
||||
self.previously_displayed_embeddings = ()
|
||||
|
||||
def add_embedding_dir(self, path):
|
||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||
|
@ -194,7 +195,7 @@ class EmbeddingDatabase:
|
|||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
||||
for root, dirs, fns in os.walk(embdir.path):
|
||||
for root, dirs, fns in os.walk(embdir.path, followlinks=True):
|
||||
for fn in fns:
|
||||
try:
|
||||
fullfn = os.path.join(root, fn)
|
||||
|
@ -228,9 +229,12 @@ class EmbeddingDatabase:
|
|||
self.load_from_dir(embdir)
|
||||
embdir.update()
|
||||
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
||||
if self.previously_displayed_embeddings != displayed_embeddings:
|
||||
self.previously_displayed_embeddings = displayed_embeddings
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
|
||||
def find_embedding_at_position(self, tokens, offset):
|
||||
token = tokens[offset]
|
||||
|
|
35
modules/timer.py
Normal file
35
modules/timer.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import time
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.start = time.time()
|
||||
self.records = {}
|
||||
self.total = 0
|
||||
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
res = end - self.start
|
||||
self.start = end
|
||||
return res
|
||||
|
||||
def record(self, category, extra_time=0):
|
||||
e = self.elapsed()
|
||||
if category not in self.records:
|
||||
self.records[category] = 0
|
||||
|
||||
self.records[category] += e + extra_time
|
||||
self.total += e + extra_time
|
||||
|
||||
def summary(self):
|
||||
res = f"{self.total:.1f}s"
|
||||
|
||||
additions = [x for x in self.records.items() if x[1] >= 0.1]
|
||||
if not additions:
|
||||
return res
|
||||
|
||||
res += " ("
|
||||
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
|
||||
res += ")"
|
||||
|
||||
return res
|
|
@ -1,5 +1,6 @@
|
|||
import modules.scripts
|
||||
from modules import sd_samplers
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||
StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
@ -8,7 +9,9 @@ import modules.processing as processing
|
|||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
|
@ -38,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||
hr_second_pass_steps=hr_second_pass_steps,
|
||||
hr_resize_x=hr_resize_x,
|
||||
hr_resize_y=hr_resize_y,
|
||||
override_settings=override_settings,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
|
|
|
@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad
|
|||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path
|
||||
from modules.paths import script_path, data_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
|
||||
|
@ -91,6 +91,7 @@ save_style_symbol = '\U0001f4be' # 💾
|
|||
apply_style_symbol = '\U0001f4cb' # 📋
|
||||
clear_prompt_symbol = '\U0001F5D1' # 🗑️
|
||||
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||
switch_values_symbol = '\U000021C5' # ⇅
|
||||
|
||||
|
||||
def plaintext_to_html(text):
|
||||
|
@ -379,6 +380,7 @@ def apply_setting(key, value):
|
|||
opts.save(shared.config_filename)
|
||||
return getattr(opts, key)
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
def refresh():
|
||||
refresh_method()
|
||||
|
@ -432,6 +434,18 @@ def get_value_for_setting(key):
|
|||
return gr.update(value=value, **args)
|
||||
|
||||
|
||||
def create_override_settings_dropdown(tabname, row):
|
||||
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
|
||||
|
||||
dropdown.change(
|
||||
fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
|
||||
inputs=[dropdown],
|
||||
outputs=[dropdown],
|
||||
)
|
||||
|
||||
return dropdown
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
@ -465,6 +479,7 @@ def create_ui():
|
|||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="txt2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||
|
@ -501,6 +516,10 @@ def create_ui():
|
|||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
||||
|
||||
elif category == "override_settings":
|
||||
with FormRow(elem_id="txt2img_override_settings_row") as row:
|
||||
override_settings = create_override_settings_dropdown('txt2img', row)
|
||||
|
||||
elif category == "scripts":
|
||||
with FormGroup(elem_id="txt2img_script_container"):
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
|
||||
|
@ -522,7 +541,6 @@ def create_ui():
|
|||
)
|
||||
|
||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
|
||||
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
@ -553,6 +571,7 @@ def create_ui():
|
|||
hr_second_pass_steps,
|
||||
hr_resize_x,
|
||||
hr_resize_y,
|
||||
override_settings,
|
||||
] + custom_inputs,
|
||||
|
||||
outputs=[
|
||||
|
@ -567,6 +586,8 @@ def create_ui():
|
|||
txt2img_prompt.submit(**txt2img_args)
|
||||
submit.click(**txt2img_args)
|
||||
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
|
||||
txt_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[
|
||||
|
@ -611,6 +632,9 @@ def create_ui():
|
|||
*modules.scripts.scripts_txt2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings,
|
||||
))
|
||||
|
||||
txt2img_preview_params = [
|
||||
txt2img_prompt,
|
||||
|
@ -691,9 +715,15 @@ def create_ui():
|
|||
|
||||
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
gr.HTML(f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
|
||||
gr.HTML(
|
||||
f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
||||
f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
||||
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
|
||||
f"{hidden}</p>"
|
||||
)
|
||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||
|
||||
def copy_image(img):
|
||||
if isinstance(img, dict) and 'image' in img:
|
||||
|
@ -727,6 +757,7 @@ def create_ui():
|
|||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="img2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||
|
@ -734,7 +765,9 @@ def create_ui():
|
|||
|
||||
elif category == "cfg":
|
||||
with FormGroup():
|
||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||
with FormRow():
|
||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
||||
|
||||
elif category == "seed":
|
||||
|
@ -751,6 +784,10 @@ def create_ui():
|
|||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
||||
|
||||
elif category == "override_settings":
|
||||
with FormRow(elem_id="img2img_override_settings_row") as row:
|
||||
override_settings = create_override_settings_dropdown('img2img', row)
|
||||
|
||||
elif category == "scripts":
|
||||
with FormGroup(elem_id="img2img_script_container"):
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
|
||||
|
@ -785,7 +822,6 @@ def create_ui():
|
|||
)
|
||||
|
||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
|
||||
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
@ -827,6 +863,7 @@ def create_ui():
|
|||
batch_count,
|
||||
batch_size,
|
||||
cfg_scale,
|
||||
image_cfg_scale,
|
||||
denoising_strength,
|
||||
seed,
|
||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
||||
|
@ -838,6 +875,8 @@ def create_ui():
|
|||
inpainting_mask_invert,
|
||||
img2img_batch_input_dir,
|
||||
img2img_batch_output_dir,
|
||||
img2img_batch_inpaint_mask_dir,
|
||||
override_settings,
|
||||
] + custom_inputs,
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
|
@ -865,6 +904,7 @@ def create_ui():
|
|||
|
||||
img2img_prompt.submit(**img2img_args)
|
||||
submit.click(**img2img_args)
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
|
||||
|
||||
img2img_interrogate.click(
|
||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||
|
@ -910,6 +950,7 @@ def create_ui():
|
|||
(sampler_index, "Sampler"),
|
||||
(restore_faces, "Face restoration"),
|
||||
(cfg_scale, "CFG scale"),
|
||||
(image_cfg_scale, "Image CFG scale"),
|
||||
(seed, "Seed"),
|
||||
(width, "Size-1"),
|
||||
(height, "Size-2"),
|
||||
|
@ -924,6 +965,9 @@ def create_ui():
|
|||
]
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
|
||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings,
|
||||
))
|
||||
|
||||
modules.scripts.scripts_current = None
|
||||
|
||||
|
@ -941,7 +985,11 @@ def create_ui():
|
|||
html2 = gr.HTML()
|
||||
with gr.Row():
|
||||
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
|
||||
parameters_copypaste.bind_buttons(buttons, image, generation_info)
|
||||
|
||||
for tabname, button in buttons.items():
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
|
||||
))
|
||||
|
||||
image.change(
|
||||
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||
|
@ -1350,6 +1398,7 @@ def create_ui():
|
|||
|
||||
components = []
|
||||
component_dict = {}
|
||||
shared.settings_components = component_dict
|
||||
|
||||
script_callbacks.ui_settings_callback()
|
||||
opts.reorder()
|
||||
|
@ -1497,8 +1546,8 @@ def create_ui():
|
|||
with open(cssfile, "r", encoding="utf8") as file:
|
||||
css += file.read() + "\n"
|
||||
|
||||
if os.path.exists(os.path.join(script_path, "user.css")):
|
||||
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
|
||||
if os.path.exists(os.path.join(data_path, "user.css")):
|
||||
with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
|
||||
css += file.read() + "\n"
|
||||
|
||||
if not cmd_opts.no_progressbar_hiding:
|
||||
|
@ -1516,8 +1565,7 @@ def create_ui():
|
|||
component = create_setting_component(k, is_quicksettings=True)
|
||||
component_dict[k] = component
|
||||
|
||||
parameters_copypaste.integrate_settings_paste_fields(component_dict)
|
||||
parameters_copypaste.run_bind()
|
||||
parameters_copypaste.connect_paste_params_buttons()
|
||||
|
||||
with gr.Tabs(elem_id="tabs") as tabs:
|
||||
for interface, label, ifid in interfaces:
|
||||
|
@ -1547,6 +1595,20 @@ def create_ui():
|
|||
outputs=[component, text_settings],
|
||||
)
|
||||
|
||||
text_settings.change(
|
||||
fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
|
||||
inputs=[],
|
||||
outputs=[image_cfg_scale],
|
||||
)
|
||||
|
||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||
button_set_checkpoint.click(
|
||||
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
|
||||
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
||||
inputs=[component_dict['sd_model_checkpoint'], dummy_component],
|
||||
outputs=[component_dict['sd_model_checkpoint'], text_settings],
|
||||
)
|
||||
|
||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
||||
|
||||
def get_settings_values():
|
||||
|
@ -1679,14 +1741,14 @@ def create_ui():
|
|||
|
||||
|
||||
def reload_javascript():
|
||||
head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}"></script>\n'
|
||||
head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
|
||||
|
||||
inline = f"{localization.localization_js(shared.opts.localization)};"
|
||||
if cmd_opts.theme is not None:
|
||||
inline += f"set_theme('{cmd_opts.theme}');"
|
||||
|
||||
for script in modules.scripts.list_scripts("javascript", ".js"):
|
||||
head += f'<script type="text/javascript" src="file={script.path}"></script>\n'
|
||||
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
|
||||
|
||||
head += f'<script type="text/javascript">{inline}</script>\n'
|
||||
|
||||
|
|
|
@ -198,5 +198,9 @@ Requested path was: {f}
|
|||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
|
||||
for paste_tabname, paste_button in buttons.items():
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
|
||||
))
|
||||
|
||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
||||
|
|
|
@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
|||
def get_block_name(self):
|
||||
return "colorpicker"
|
||||
|
||||
|
||||
class DropdownMulti(gr.Dropdown):
|
||||
"""Same as gr.Dropdown but always multiselect"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(multiselect=True, **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "dropdown"
|
||||
|
|
|
@ -13,7 +13,7 @@ import shutil
|
|||
import errno
|
||||
|
||||
from modules import extensions, shared, paths
|
||||
|
||||
from modules.call_queue import wrap_gradio_gpu_call
|
||||
|
||||
available_extensions = {"extensions": []}
|
||||
|
||||
|
@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list):
|
|||
shared.state.need_restart = True
|
||||
|
||||
|
||||
def check_updates():
|
||||
def check_updates(id_task, disable_list):
|
||||
check_access()
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.remote is None:
|
||||
continue
|
||||
disabled = json.loads(disable_list)
|
||||
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
||||
|
||||
exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
|
||||
shared.state.job_count = len(exts)
|
||||
|
||||
for ext in exts:
|
||||
shared.state.textinfo = ext.name
|
||||
|
||||
try:
|
||||
ext.check_updates()
|
||||
|
@ -63,7 +68,9 @@ def check_updates():
|
|||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return extension_table()
|
||||
shared.state.nextjob()
|
||||
|
||||
return extension_table(), ""
|
||||
|
||||
|
||||
def extension_table():
|
||||
|
@ -132,7 +139,7 @@ def install_extension_from_url(dirname, url):
|
|||
normalized_url = normalize_git_url(url)
|
||||
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
|
||||
|
||||
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
|
||||
tmpdir = os.path.join(paths.data_path, "tmp", dirname)
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
@ -273,12 +280,13 @@ def create_ui():
|
|||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||
with gr.TabItem("Installed"):
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(elem_id="extensions_installed_top"):
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
|
||||
|
||||
info = gr.HTML()
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
|
||||
apply.click(
|
||||
|
@ -289,10 +297,10 @@ def create_ui():
|
|||
)
|
||||
|
||||
check.click(
|
||||
fn=check_updates,
|
||||
fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
|
||||
_js="extensions_check",
|
||||
inputs=[],
|
||||
outputs=[extensions_table],
|
||||
inputs=[info, extensions_disabled_list],
|
||||
outputs=[extensions_table, info],
|
||||
)
|
||||
|
||||
with gr.TabItem("Available"):
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import glob
|
||||
import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
from modules import shared
|
||||
import gradio as gr
|
||||
|
@ -8,12 +11,32 @@ import html
|
|||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
|
||||
extra_pages = []
|
||||
allowed_dirs = set()
|
||||
|
||||
|
||||
def register_page(page):
|
||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||
|
||||
extra_pages.append(page)
|
||||
allowed_dirs.clear()
|
||||
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
||||
|
||||
|
||||
def add_pages_to_demo(app):
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
|
@ -26,10 +49,44 @@ class ExtraNetworksPage:
|
|||
def refresh(self):
|
||||
pass
|
||||
|
||||
def link_preview(self, filename):
|
||||
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
|
||||
|
||||
def search_terms_from_path(self, filename, possible_directories=None):
|
||||
abspath = os.path.abspath(filename)
|
||||
|
||||
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
||||
parentdir = os.path.abspath(parentdir)
|
||||
if abspath.startswith(parentdir):
|
||||
return abspath[len(parentdir):].replace('\\', '/')
|
||||
|
||||
return ""
|
||||
|
||||
def create_html(self, tabname):
|
||||
view = shared.opts.extra_networks_default_view
|
||||
items_html = ''
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
|
||||
if not os.path.isdir(x):
|
||||
continue
|
||||
|
||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
||||
while subdir.startswith("/"):
|
||||
subdir = subdir[1:]
|
||||
|
||||
subdirs[subdir] = 1
|
||||
|
||||
if subdirs:
|
||||
subdirs = {"": 1, **subdirs}
|
||||
|
||||
subdirs_html = "".join([f"""
|
||||
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
||||
{html.escape(subdir if subdir!="" else "all")}
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
|
||||
for item in self.list_items():
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
|
@ -38,6 +95,9 @@ class ExtraNetworksPage:
|
|||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
res = f"""
|
||||
<div id='{tabname}_{self.name}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
|
||||
{subdirs_html}
|
||||
</div>
|
||||
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
|
||||
{items_html}
|
||||
</div>
|
||||
|
@ -54,14 +114,19 @@ class ExtraNetworksPage:
|
|||
def create_html_for_item(self, item, tabname):
|
||||
preview = item.get("preview", None)
|
||||
|
||||
onclick = item.get("onclick", None)
|
||||
if onclick is None:
|
||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
|
||||
args = {
|
||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||
"prompt": item["prompt"],
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
"name": item["name"],
|
||||
"card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
|
||||
"card_clicked": onclick,
|
||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||
"search_term": item.get("search_term", ""),
|
||||
}
|
||||
|
||||
return self.card_page.format(**args)
|
||||
|
@ -117,8 +182,13 @@ def create_ui(container, button, tabname):
|
|||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||
|
||||
button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
|
||||
button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
|
||||
def toggle_visibility(is_visible):
|
||||
is_visible = not is_visible
|
||||
return is_visible, gr.update(visible=is_visible)
|
||||
|
||||
state_visible = gr.State(value=False)
|
||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
||||
button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
||||
|
||||
def refresh():
|
||||
res = []
|
||||
|
@ -138,7 +208,7 @@ def path_is_parent(parent_path, child_path):
|
|||
parent_path = os.path.abspath(parent_path)
|
||||
child_path = os.path.abspath(child_path)
|
||||
|
||||
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
|
||||
return child_path.startswith(parent_path)
|
||||
|
||||
|
||||
def setup_ui(ui, gallery):
|
||||
|
@ -168,7 +238,8 @@ def setup_ui(ui, gallery):
|
|||
|
||||
ui.button_save_preview.click(
|
||||
fn=save_preview,
|
||||
_js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
|
||||
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
||||
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
||||
outputs=[*ui.pages]
|
||||
)
|
||||
|
||||
|
|
39
modules/ui_extra_networks_checkpoints.py
Normal file
39
modules/ui_extra_networks_checkpoints.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
import html
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
|
||||
from modules import shared, ui_extra_networks, sd_models
|
||||
|
||||
|
||||
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Checkpoints')
|
||||
|
||||
def refresh(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def list_items(self):
|
||||
checkpoint: sd_models.CheckpointInfo
|
||||
for name, checkpoint in sd_models.checkpoints_list.items():
|
||||
path, ext = os.path.splitext(checkpoint.filename)
|
||||
previews = [path + ".png", path + ".preview.png"]
|
||||
|
||||
preview = None
|
||||
for file in previews:
|
||||
if os.path.isfile(file):
|
||||
preview = self.link_preview(file)
|
||||
break
|
||||
|
||||
yield {
|
||||
"name": checkpoint.name_for_extra,
|
||||
"filename": path,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
|
||||
"local_preview": path + ".png",
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||
|
|
@ -19,13 +19,14 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||
preview = None
|
||||
for file in previews:
|
||||
if os.path.isfile(file):
|
||||
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
||||
preview = self.link_preview(file)
|
||||
break
|
||||
|
||||
yield {
|
||||
"name": name,
|
||||
"filename": path,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(path),
|
||||
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||
"local_preview": path + ".png",
|
||||
}
|
||||
|
|
|
@ -19,12 +19,13 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||
|
||||
preview = None
|
||||
if os.path.isfile(preview_file):
|
||||
preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
|
||||
preview = self.link_preview(preview_file)
|
||||
|
||||
yield {
|
||||
"name": embedding.name,
|
||||
"filename": embedding.filename,
|
||||
"preview": preview,
|
||||
"search_term": self.search_terms_from_path(embedding.filename),
|
||||
"prompt": json.dumps(embedding.name),
|
||||
"local_preview": path + ".preview.png",
|
||||
}
|
||||
|
|
|
@ -11,7 +11,6 @@ from modules import modelloader, shared
|
|||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class Upscaler:
|
||||
|
@ -39,7 +38,7 @@ class Upscaler:
|
|||
self.mod_scale = None
|
||||
|
||||
if self.model_path is None and self.name:
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.model_path = os.path.join(shared.models_path, self.name)
|
||||
if self.model_path and create_dirs:
|
||||
os.makedirs(self.model_path, exist_ok=True)
|
||||
|
||||
|
@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler):
|
|||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.name = "Nearest"
|
||||
self.scalers = [UpscalerData("Nearest", None, self)]
|
||||
self.scalers = [UpscalerData("Nearest", None, self)]
|
||||
|
|
|
@ -16,7 +16,7 @@ pytorch_lightning==1.7.7
|
|||
realesrgan
|
||||
scikit-image>=0.19
|
||||
timm==0.4.12
|
||||
transformers==4.19.2
|
||||
transformers==4.25.1
|
||||
torch
|
||||
einops
|
||||
jsonmerge
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
blendmodes==2022
|
||||
transformers==4.19.2
|
||||
transformers==4.25.1
|
||||
accelerate==0.12.0
|
||||
basicsr==1.4.2
|
||||
gfpgan==1.3.8
|
||||
|
|
|
@ -6,7 +6,7 @@ from tqdm import trange
|
|||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import processing, shared, sd_samplers, prompt_parser
|
||||
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
|
||||
from modules.processing import Processed
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
|
||||
|
@ -50,7 +50,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
|||
|
||||
x = x + d * dt
|
||||
|
||||
sd_samplers.store_latent(x)
|
||||
sd_samplers_common.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||
|
@ -104,7 +104,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
|||
dt = sigmas[i] - sigmas[i - 1]
|
||||
x = x + d * dt
|
||||
|
||||
sd_samplers.store_latent(x)
|
||||
sd_samplers_common.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||
|
|
|
@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||
|
||||
def image_changed(self):
|
||||
upscale_cache.clear()
|
||||
|
||||
|
||||
class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
|
||||
name = "Simple Upscale"
|
||||
order = 900
|
||||
|
||||
def ui(self):
|
||||
with FormRow():
|
||||
upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||
upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
|
||||
|
||||
return {
|
||||
"upscale_by": upscale_by,
|
||||
"upscaler_name": upscaler_name,
|
||||
}
|
||||
|
||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
|
||||
if upscaler_name is None or upscaler_name == "None":
|
||||
return
|
||||
|
||||
upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
|
||||
assert upscaler1, f'could not find upscaler named {upscaler_name}'
|
||||
|
||||
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
|
||||
pp.info[f"Postprocess upscaler"] = upscaler1.name
|
||||
|
|
|
@ -44,16 +44,34 @@ class Script(scripts.Script):
|
|||
def title(self):
|
||||
return "Prompt matrix"
|
||||
|
||||
def ui(self, is_img2img):
|
||||
put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
|
||||
different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
|
||||
def ui(self, is_img2img):
|
||||
gr.HTML('<br />')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
|
||||
different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
|
||||
with gr.Column():
|
||||
prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
|
||||
variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
|
||||
with gr.Column():
|
||||
margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
||||
|
||||
return [put_at_start, different_seeds]
|
||||
return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
|
||||
|
||||
def run(self, p, put_at_start, different_seeds):
|
||||
def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
|
||||
modules.processing.fix_seed(p)
|
||||
# Raise error if promp type is not positive or negative
|
||||
if prompt_type not in ["positive", "negative"]:
|
||||
raise ValueError(f"Unknown prompt type {prompt_type}")
|
||||
# Raise error if variations delimiter is not comma or space
|
||||
if variations_delimiter not in ["comma", "space"]:
|
||||
raise ValueError(f"Unknown variations delimiter {variations_delimiter}")
|
||||
|
||||
original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
||||
prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
|
||||
original_prompt = prompt[0] if type(prompt) == list else prompt
|
||||
positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
||||
|
||||
delimiter = ", " if variations_delimiter == "comma" else " "
|
||||
|
||||
all_prompts = []
|
||||
prompt_matrix_parts = original_prompt.split("|")
|
||||
|
@ -66,20 +84,23 @@ class Script(scripts.Script):
|
|||
else:
|
||||
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
|
||||
|
||||
all_prompts.append(", ".join(selected_prompts))
|
||||
all_prompts.append(delimiter.join(selected_prompts))
|
||||
|
||||
p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
|
||||
p.do_not_save_grid = True
|
||||
|
||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
|
||||
|
||||
p.prompt = all_prompts
|
||||
if prompt_type == "positive":
|
||||
p.prompt = all_prompts
|
||||
else:
|
||||
p.negative_prompt = all_prompts
|
||||
p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
|
||||
p.prompt_for_display = original_prompt
|
||||
p.prompt_for_display = positive_prompt
|
||||
processed = process_images(p)
|
||||
|
||||
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
|
||||
grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
|
||||
grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts, margin_size)
|
||||
processed.images.insert(0, grid)
|
||||
processed.index_of_first_image = 1
|
||||
processed.infotexts.insert(0, processed.infotexts[0])
|
||||
|
|
|
@ -123,7 +123,7 @@ def apply_vae(p, x, xs):
|
|||
|
||||
|
||||
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
|
||||
p.styles = x.split(',')
|
||||
p.styles.extend(x.split(','))
|
||||
|
||||
|
||||
def format_value_add_label(p, opt, x):
|
||||
|
@ -205,26 +205,30 @@ axis_options = [
|
|||
]
|
||||
|
||||
|
||||
def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
|
||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
|
||||
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
||||
|
||||
# Temporary list of all the images that are generated to be populated into the grid.
|
||||
# Will be filled with empty images for any individual step that fails to process properly
|
||||
image_cache = [None] * (len(xs) * len(ys))
|
||||
image_cache = [None] * (len(xs) * len(ys) * len(zs))
|
||||
|
||||
processed_result = None
|
||||
cell_mode = "P"
|
||||
cell_size = (1, 1)
|
||||
|
||||
state.job_count = len(xs) * len(ys) * p.n_iter
|
||||
state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
|
||||
|
||||
def process_cell(x, y, ix, iy):
|
||||
def process_cell(x, y, z, ix, iy, iz):
|
||||
nonlocal image_cache, processed_result, cell_mode, cell_size
|
||||
|
||||
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
||||
def index(ix, iy, iz):
|
||||
return ix + iy * len(xs) + iz * len(xs) * len(ys)
|
||||
|
||||
processed: Processed = cell(x, y)
|
||||
state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
|
||||
|
||||
processed: Processed = cell(x, y, z)
|
||||
|
||||
try:
|
||||
# this dereference will throw an exception if the image was not processed
|
||||
|
@ -238,35 +242,68 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
|
|||
cell_size = processed_image.size
|
||||
processed_result.images = [Image.new(cell_mode, cell_size)]
|
||||
|
||||
image_cache[ix + iy * len(xs)] = processed_image
|
||||
image_cache[index(ix, iy, iz)] = processed_image
|
||||
if include_lone_images:
|
||||
processed_result.images.append(processed_image)
|
||||
processed_result.all_prompts.append(processed.prompt)
|
||||
processed_result.all_seeds.append(processed.seed)
|
||||
processed_result.infotexts.append(processed.infotexts[0])
|
||||
except:
|
||||
image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
|
||||
image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
|
||||
|
||||
if swap_axes_processing_order:
|
||||
if first_axes_processed == 'x':
|
||||
for ix, x in enumerate(xs):
|
||||
for iy, y in enumerate(ys):
|
||||
process_cell(x, y, ix, iy)
|
||||
else:
|
||||
if second_axes_processed == 'y':
|
||||
for iy, y in enumerate(ys):
|
||||
for iz, z in enumerate(zs):
|
||||
process_cell(x, y, z, ix, iy, iz)
|
||||
else:
|
||||
for iz, z in enumerate(zs):
|
||||
for iy, y in enumerate(ys):
|
||||
process_cell(x, y, z, ix, iy, iz)
|
||||
elif first_axes_processed == 'y':
|
||||
for iy, y in enumerate(ys):
|
||||
for ix, x in enumerate(xs):
|
||||
process_cell(x, y, ix, iy)
|
||||
if second_axes_processed == 'x':
|
||||
for ix, x in enumerate(xs):
|
||||
for iz, z in enumerate(zs):
|
||||
process_cell(x, y, z, ix, iy, iz)
|
||||
else:
|
||||
for iz, z in enumerate(zs):
|
||||
for ix, x in enumerate(xs):
|
||||
process_cell(x, y, z, ix, iy, iz)
|
||||
elif first_axes_processed == 'z':
|
||||
for iz, z in enumerate(zs):
|
||||
if second_axes_processed == 'x':
|
||||
for ix, x in enumerate(xs):
|
||||
for iy, y in enumerate(ys):
|
||||
process_cell(x, y, z, ix, iy, iz)
|
||||
else:
|
||||
for iy, y in enumerate(ys):
|
||||
for ix, x in enumerate(xs):
|
||||
process_cell(x, y, z, ix, iy, iz)
|
||||
|
||||
if not processed_result:
|
||||
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
|
||||
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
||||
return Processed(p, [])
|
||||
|
||||
grid = images.image_grid(image_cache, rows=len(ys))
|
||||
sub_grids = [None] * len(zs)
|
||||
for i in range(len(zs)):
|
||||
start_index = i * len(xs) * len(ys)
|
||||
end_index = start_index + len(xs) * len(ys)
|
||||
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
|
||||
if draw_legend:
|
||||
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
|
||||
sub_grids[i] = grid
|
||||
if include_sub_grids and len(zs) > 1:
|
||||
processed_result.images.insert(i+1, grid)
|
||||
|
||||
sub_grid_size = sub_grids[0].size
|
||||
z_grid = images.image_grid(sub_grids, rows=1)
|
||||
if draw_legend:
|
||||
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
|
||||
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||
processed_result.images[0] = z_grid
|
||||
|
||||
processed_result.images[0] = grid
|
||||
|
||||
return processed_result
|
||||
return processed_result, sub_grids
|
||||
|
||||
|
||||
class SharedSettingsStackHelper(object):
|
||||
|
@ -291,7 +328,7 @@ re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+
|
|||
|
||||
class Script(scripts.Script):
|
||||
def title(self):
|
||||
return "X/Y plot"
|
||||
return "X/Y/Z plot"
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
|
||||
|
@ -301,24 +338,42 @@ class Script(scripts.Script):
|
|||
with gr.Row():
|
||||
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
||||
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
|
||||
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
|
||||
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
|
||||
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
|
||||
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
|
||||
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
|
||||
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
|
||||
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
||||
|
||||
with gr.Row(variant="compact", elem_id="axis_options"):
|
||||
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
||||
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
||||
swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
|
||||
with gr.Column():
|
||||
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
||||
with gr.Column():
|
||||
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
||||
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
||||
with gr.Column():
|
||||
margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
||||
|
||||
with gr.Row(variant="compact", elem_id="swap_axes"):
|
||||
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
||||
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
|
||||
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
|
||||
|
||||
def swap_axes(x_type, x_values, y_type, y_values):
|
||||
return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values
|
||||
def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
|
||||
return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
|
||||
|
||||
swap_args = [x_type, x_values, y_type, y_values]
|
||||
swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
|
||||
xy_swap_args = [x_type, x_values, y_type, y_values]
|
||||
swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
|
||||
yz_swap_args = [y_type, y_values, z_type, z_values]
|
||||
swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
|
||||
xz_swap_args = [x_type, x_values, z_type, z_values]
|
||||
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
|
||||
|
||||
def fill(x_type):
|
||||
axis = self.current_axis_options[x_type]
|
||||
|
@ -326,16 +381,27 @@ class Script(scripts.Script):
|
|||
|
||||
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
|
||||
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
|
||||
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
|
||||
|
||||
def select_axis(x_type):
|
||||
return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
|
||||
|
||||
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
|
||||
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
|
||||
z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
|
||||
|
||||
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
|
||||
self.infotext_fields = (
|
||||
(x_type, "X Type"),
|
||||
(x_values, "X Values"),
|
||||
(y_type, "Y Type"),
|
||||
(y_values, "Y Values"),
|
||||
(z_type, "Z Type"),
|
||||
(z_values, "Z Values"),
|
||||
)
|
||||
|
||||
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
|
||||
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
|
||||
|
||||
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
|
||||
if not no_fixed_seeds:
|
||||
modules.processing.fix_seed(p)
|
||||
|
||||
|
@ -409,6 +475,9 @@ class Script(scripts.Script):
|
|||
y_opt = self.current_axis_options[y_type]
|
||||
ys = process_axis(y_opt, y_values)
|
||||
|
||||
z_opt = self.current_axis_options[z_type]
|
||||
zs = process_axis(z_opt, z_values)
|
||||
|
||||
def fix_axis_seeds(axis_opt, axis_list):
|
||||
if axis_opt.label in ['Seed', 'Var. seed']:
|
||||
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
|
||||
|
@ -418,21 +487,26 @@ class Script(scripts.Script):
|
|||
if not no_fixed_seeds:
|
||||
xs = fix_axis_seeds(x_opt, xs)
|
||||
ys = fix_axis_seeds(y_opt, ys)
|
||||
zs = fix_axis_seeds(z_opt, zs)
|
||||
|
||||
if x_opt.label == 'Steps':
|
||||
total_steps = sum(xs) * len(ys)
|
||||
total_steps = sum(xs) * len(ys) * len(zs)
|
||||
elif y_opt.label == 'Steps':
|
||||
total_steps = sum(ys) * len(xs)
|
||||
total_steps = sum(ys) * len(xs) * len(zs)
|
||||
elif z_opt.label == 'Steps':
|
||||
total_steps = sum(zs) * len(xs) * len(ys)
|
||||
else:
|
||||
total_steps = p.steps * len(xs) * len(ys)
|
||||
total_steps = p.steps * len(xs) * len(ys) * len(zs)
|
||||
|
||||
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
|
||||
if x_opt.label == "Hires steps":
|
||||
total_steps += sum(xs) * len(ys)
|
||||
total_steps += sum(xs) * len(ys) * len(zs)
|
||||
elif y_opt.label == "Hires steps":
|
||||
total_steps += sum(ys) * len(xs)
|
||||
total_steps += sum(ys) * len(xs) * len(zs)
|
||||
elif z_opt.label == "Hires steps":
|
||||
total_steps += sum(zs) * len(xs) * len(ys)
|
||||
elif p.hr_second_pass_steps:
|
||||
total_steps += p.hr_second_pass_steps * len(xs) * len(ys)
|
||||
total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
|
||||
else:
|
||||
total_steps *= 2
|
||||
|
||||
|
@ -440,7 +514,8 @@ class Script(scripts.Script):
|
|||
|
||||
image_cell_count = p.n_iter * p.batch_size
|
||||
cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
|
||||
print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})")
|
||||
plural_s = 's' if len(zs) > 1 else ''
|
||||
print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
|
||||
shared.total_tqdm.updateTotal(total_steps)
|
||||
|
||||
grid_infotext = [None]
|
||||
|
@ -448,20 +523,42 @@ class Script(scripts.Script):
|
|||
# If one of the axes is very slow to change between (like SD model
|
||||
# checkpoint), then make sure it is in the outer iteration of the nested
|
||||
# `for` loop.
|
||||
swap_axes_processing_order = x_opt.cost > y_opt.cost
|
||||
first_axes_processed = 'x'
|
||||
second_axes_processed = 'y'
|
||||
if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
|
||||
first_axes_processed = 'x'
|
||||
if y_opt.cost > z_opt.cost:
|
||||
second_axes_processed = 'y'
|
||||
else:
|
||||
second_axes_processed = 'z'
|
||||
elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
|
||||
first_axes_processed = 'y'
|
||||
if x_opt.cost > z_opt.cost:
|
||||
second_axes_processed = 'x'
|
||||
else:
|
||||
second_axes_processed = 'z'
|
||||
elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
|
||||
first_axes_processed = 'z'
|
||||
if x_opt.cost > y_opt.cost:
|
||||
second_axes_processed = 'x'
|
||||
else:
|
||||
second_axes_processed = 'y'
|
||||
|
||||
def cell(x, y):
|
||||
def cell(x, y, z):
|
||||
if shared.state.interrupted:
|
||||
return Processed(p, [], p.seed, "")
|
||||
|
||||
pc = copy(p)
|
||||
pc.styles = pc.styles[:]
|
||||
x_opt.apply(pc, x, xs)
|
||||
y_opt.apply(pc, y, ys)
|
||||
z_opt.apply(pc, z, zs)
|
||||
|
||||
res = process_images(pc)
|
||||
|
||||
if grid_infotext[0] is None:
|
||||
pc.extra_generation_params = copy(pc.extra_generation_params)
|
||||
pc.extra_generation_params['Script'] = self.title()
|
||||
|
||||
if x_opt.label != 'Nothing':
|
||||
pc.extra_generation_params["X Type"] = x_opt.label
|
||||
|
@ -475,24 +572,39 @@ class Script(scripts.Script):
|
|||
if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
|
||||
pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
|
||||
|
||||
if z_opt.label != 'Nothing':
|
||||
pc.extra_generation_params["Z Type"] = z_opt.label
|
||||
pc.extra_generation_params["Z Values"] = z_values
|
||||
if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
|
||||
pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])
|
||||
|
||||
grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
|
||||
|
||||
return res
|
||||
|
||||
with SharedSettingsStackHelper():
|
||||
processed = draw_xy_grid(
|
||||
processed, sub_grids = draw_xyz_grid(
|
||||
p,
|
||||
xs=xs,
|
||||
ys=ys,
|
||||
zs=zs,
|
||||
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
|
||||
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
|
||||
z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
|
||||
cell=cell,
|
||||
draw_legend=draw_legend,
|
||||
include_lone_images=include_lone_images,
|
||||
swap_axes_processing_order=swap_axes_processing_order
|
||||
include_sub_grids=include_sub_grids,
|
||||
first_axes_processed=first_axes_processed,
|
||||
second_axes_processed=second_axes_processed,
|
||||
margin_size=margin_size
|
||||
)
|
||||
|
||||
if opts.grid_save and len(sub_grids) > 1:
|
||||
for sub_grid in sub_grids:
|
||||
images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||
|
||||
if opts.grid_save:
|
||||
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||
images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||
|
||||
return processed
|
53
style.css
53
style.css
|
@ -74,7 +74,12 @@
|
|||
#txt2img_gallery img, #img2img_gallery img{
|
||||
object-fit: scale-down;
|
||||
}
|
||||
|
||||
#txt2img_actions_column, #img2img_actions_column {
|
||||
margin: 0.35rem 0.75rem 0.35rem 0;
|
||||
}
|
||||
#script_list {
|
||||
padding: .625rem .75rem 0 .625rem;
|
||||
}
|
||||
.justify-center.overflow-x-scroll {
|
||||
justify-content: left;
|
||||
}
|
||||
|
@ -126,6 +131,7 @@
|
|||
|
||||
#txt2img_actions_column, #img2img_actions_column{
|
||||
gap: 0;
|
||||
margin-right: .75rem;
|
||||
}
|
||||
|
||||
#txt2img_tools, #img2img_tools{
|
||||
|
@ -150,6 +156,7 @@
|
|||
|
||||
#txt2img_styles_row, #img2img_styles_row{
|
||||
gap: 0.25em;
|
||||
margin-top: 0.3em;
|
||||
}
|
||||
|
||||
#txt2img_styles_row > button, #img2img_styles_row > button{
|
||||
|
@ -164,7 +171,7 @@
|
|||
min-height: 3.2em;
|
||||
}
|
||||
|
||||
#txt2img_styles ul, #img2img_styles ul{
|
||||
ul.list-none{
|
||||
max-height: 35em;
|
||||
z-index: 2000;
|
||||
}
|
||||
|
@ -311,11 +318,11 @@ input[type="range"]{
|
|||
.min-h-\[6rem\] { min-height: unset !important; }
|
||||
|
||||
.progressDiv{
|
||||
position: absolute;
|
||||
position: relative;
|
||||
height: 20px;
|
||||
top: -20px;
|
||||
background: #b4c0cc;
|
||||
border-radius: 3px !important;
|
||||
margin-bottom: -3px;
|
||||
}
|
||||
|
||||
.dark .progressDiv{
|
||||
|
@ -535,7 +542,7 @@ input[type="range"]{
|
|||
}
|
||||
|
||||
#quicksettings {
|
||||
gap: 0.4em;
|
||||
width: fit-content;
|
||||
}
|
||||
|
||||
#quicksettings > div, #quicksettings > fieldset{
|
||||
|
@ -545,6 +552,7 @@ input[type="range"]{
|
|||
border: none;
|
||||
box-shadow: none;
|
||||
background: none;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
#quicksettings > div > div > div > label > span {
|
||||
|
@ -567,7 +575,7 @@ canvas[key="mask"] {
|
|||
right: 0.5em;
|
||||
top: -0.6em;
|
||||
z-index: 400;
|
||||
width: 8em;
|
||||
width: 6em;
|
||||
}
|
||||
#quicksettings .gr-box > div > div > input.gr-text-input {
|
||||
top: -1.12em;
|
||||
|
@ -665,11 +673,27 @@ canvas[key="mask"] {
|
|||
|
||||
#quicksettings .gr-button-tool{
|
||||
margin: 0;
|
||||
border-color: unset;
|
||||
background-color: unset;
|
||||
}
|
||||
|
||||
|
||||
#modelmerger_interp_description>p {
|
||||
margin: 0!important;
|
||||
text-align: center;
|
||||
}
|
||||
#modelmerger_interp_description {
|
||||
margin: 0.35rem 0.75rem 1.23rem;
|
||||
}
|
||||
#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
|
||||
padding-top: 0.9em;
|
||||
padding-bottom: 0.9em;
|
||||
}
|
||||
#txt2img_settings {
|
||||
padding-top: 1.16em;
|
||||
padding-bottom: 0.9em;
|
||||
}
|
||||
#img2img_settings {
|
||||
padding-bottom: 0.9em;
|
||||
}
|
||||
|
||||
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
|
||||
|
@ -714,9 +738,6 @@ footer {
|
|||
white-space: nowrap;
|
||||
min-width: auto;
|
||||
}
|
||||
#txt2img_hires_fix{
|
||||
margin-left: -0.8em;
|
||||
}
|
||||
|
||||
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
|
||||
margin-left: 0em;
|
||||
|
@ -744,7 +765,7 @@ footer {
|
|||
|
||||
.dark .gr-compact{
|
||||
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
|
||||
margin-left: 0.8em;
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.gr-compact{
|
||||
|
@ -786,7 +807,13 @@ footer {
|
|||
margin: 0.3em;
|
||||
}
|
||||
|
||||
.extra-network-subdirs{
|
||||
padding: 0.2em 0.35em;
|
||||
}
|
||||
|
||||
.extra-network-subdirs button{
|
||||
margin: 0 0.15em;
|
||||
}
|
||||
|
||||
#txt2img_extra_networks .search, #img2img_extra_networks .search{
|
||||
display: inline-block;
|
||||
|
@ -857,6 +884,7 @@ footer {
|
|||
white-space: nowrap;
|
||||
text-overflow: ellipsis;
|
||||
background: rgba(0,0,0,.5);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.extra-network-thumbs .card:hover .actions .name {
|
||||
|
@ -928,3 +956,6 @@ footer {
|
|||
color: red;
|
||||
}
|
||||
|
||||
[id*='_prompt_container'] > div {
|
||||
margin: 0!important;
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ then
|
|||
fi
|
||||
|
||||
export install_dir="$HOME"
|
||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
|
||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
|
||||
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
|
||||
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
||||
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
||||
|
|
|
@ -3,6 +3,6 @@
|
|||
set PYTHON=
|
||||
set GIT=
|
||||
set VENV_DIR=
|
||||
set COMMANDLINE_ARGS=
|
||||
set COMMANDLINE_ARGS=--skip-torch-cuda-test --precision full --no-half
|
||||
|
||||
call webui.bat
|
||||
|
|
13
webui.bat
13
webui.bat
|
@ -9,10 +9,19 @@ set ERROR_REPORTING=FALSE
|
|||
mkdir tmp 2>NUL
|
||||
|
||||
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% == 0 goto :start_venv
|
||||
if %ERRORLEVEL% == 0 goto :check_pip
|
||||
echo Couldn't launch python
|
||||
goto :show_stdout_stderr
|
||||
|
||||
:check_pip
|
||||
%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% == 0 goto :start_venv
|
||||
if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
|
||||
%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
|
||||
if %ERRORLEVEL% == 0 goto :start_venv
|
||||
echo Couldn't install pip
|
||||
goto :show_stdout_stderr
|
||||
|
||||
:start_venv
|
||||
if ["%VENV_DIR%"] == ["-"] goto :skip_venv
|
||||
if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
|
||||
|
@ -46,7 +55,7 @@ pause
|
|||
exit /b
|
||||
|
||||
:accelerate_launch
|
||||
echo "Accelerating"
|
||||
echo Accelerating
|
||||
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
|
||||
pause
|
||||
exit /b
|
||||
|
|
17
webui.py
17
webui.py
|
@ -12,10 +12,9 @@ from packaging import version
|
|||
import logging
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
from modules import import_hook, errors, extra_networks
|
||||
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
|
||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||
from modules.paths import script_path
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -53,6 +52,9 @@ else:
|
|||
|
||||
|
||||
def check_versions():
|
||||
if shared.cmd_opts.skip_version_check:
|
||||
return
|
||||
|
||||
expected_torch_version = "1.13.1"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
|
@ -60,7 +62,10 @@ def check_versions():
|
|||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded.
|
||||
Beware that this will cause a lot of large files to be downloaded, as well as
|
||||
there are reports of issues with training tab on the latest version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
expected_xformers_version = "0.0.16rc425"
|
||||
|
@ -72,6 +77,8 @@ Beware that this will cause a lot of large files to be downloaded.
|
|||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
|
||||
|
@ -120,6 +127,7 @@ def initialize():
|
|||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
@ -228,6 +236,8 @@ def webui():
|
|||
if launch_api:
|
||||
create_api(app)
|
||||
|
||||
ui_extra_networks.add_pages_to_demo(app)
|
||||
|
||||
modules.script_callbacks.app_started_callback(shared.demo, app)
|
||||
|
||||
wait_on_server(shared.demo)
|
||||
|
@ -255,6 +265,7 @@ def webui():
|
|||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
|
Loading…
Reference in New Issue
Block a user