Merge branch 'master' into saving
This commit is contained in:
commit
a9d7eb722f
8
.gitignore
vendored
8
.gitignore
vendored
|
@ -1,10 +1,13 @@
|
|||
__pycache__
|
||||
/ESRGAN
|
||||
*.ckpt
|
||||
*.pth
|
||||
/ESRGAN/*
|
||||
/SwinIR/*
|
||||
/repositories
|
||||
/venv
|
||||
/tmp
|
||||
/model.ckpt
|
||||
/models/**/*.ckpt
|
||||
/models/**/*
|
||||
/GFPGANv1.3.pth
|
||||
/gfpgan/weights/*.pth
|
||||
/ui-config.json
|
||||
|
@ -22,3 +25,4 @@ __pycache__
|
|||
/.idea
|
||||
notification.mp3
|
||||
/SwinIR
|
||||
/textual_inversion
|
||||
|
|
59
README.md
59
README.md
|
@ -3,50 +3,64 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||
|
||||
![](txt2img_Screenshot.png)
|
||||
|
||||
Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
|
||||
|
||||
## Features
|
||||
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
|
||||
- Original txt2img and img2img modes
|
||||
- One click install and run script (but you still must install python and git)
|
||||
- Outpainting
|
||||
- Inpainting
|
||||
- Prompt matrix
|
||||
- Stable Diffusion upscale
|
||||
- Attention
|
||||
- Loopback
|
||||
- X/Y plot
|
||||
- Prompt Matrix
|
||||
- Stable Diffusion Upscale
|
||||
- Attention, specify parts of text that the model should pay more attention to
|
||||
- a man in a ((tuxedo)) - will pay more attention to tuxedo
|
||||
- a man in a (tuxedo:1.21) - alternative syntax
|
||||
- Loopback, run img2img processing multiple times
|
||||
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
||||
- Textual Inversion
|
||||
- have as many embeddings as you want and use any names you like for them
|
||||
- use multiple embeddings with different numbers of vectors per token
|
||||
- works with half precision floating point numbers
|
||||
- Extras tab with:
|
||||
- GFPGAN, neural network that fixes faces
|
||||
- CodeFormer, face restoration tool as an alternative to GFPGAN
|
||||
- RealESRGAN, neural network upscaler
|
||||
- ESRGAN, neural network with a lot of third party models
|
||||
- ESRGAN, neural network upscaler with a lot of third party models
|
||||
- SwinIR, neural network upscaler
|
||||
- LDSR, Latent diffusion super resolution upscaling
|
||||
- Resizing aspect ratio options
|
||||
- Sampling method selection
|
||||
- Interrupt processing at any time
|
||||
- 4GB video card support
|
||||
- Correct seeds for batches
|
||||
- 4GB video card support (also reports of 2GB working)
|
||||
- Correct seeds for batches
|
||||
- Prompt length validation
|
||||
- Generation parameters added as text to PNG
|
||||
- Tab to view an existing picture's generation parameters
|
||||
- get length of prompt in tokens as you type
|
||||
- get a warning after generation if some text was truncated
|
||||
- Generation parameters
|
||||
- parameters you used to generate images are saved with that image
|
||||
- in PNG chunks for PNG, in EXIF for JPEG
|
||||
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
||||
- can be disabled in settings
|
||||
- Settings page
|
||||
- Running custom code from UI
|
||||
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||
- Mouseover hints for most UI elements
|
||||
- Possible to change defaults/mix/max/step values for UI elements via text config
|
||||
- Random artist button
|
||||
- Tiling support: UI checkbox to create images that can be tiled like textures
|
||||
- Tiling support, a checkbox to create images that can be tiled like textures
|
||||
- Progress bar and live image generation preview
|
||||
- Negative prompt
|
||||
- Styles
|
||||
- Variations
|
||||
- Seed resizing
|
||||
- CLIP interrogator
|
||||
- Prompt Editing
|
||||
- Batch Processing
|
||||
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
|
||||
- Styles, a way to save part of prompt and easily apply them via dropdown later
|
||||
- Variations, a way to generate same image but with tiny differences
|
||||
- Seed resizing, a way to generate same image but at slightly different resolution
|
||||
- CLIP interrogator, a button that tries to guess prompt from an image
|
||||
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
|
||||
- Batch Processing, process a group of files using img2img
|
||||
- Img2img Alternative
|
||||
- Highres Fix
|
||||
- LDSR Upscaling
|
||||
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
|
||||
- Reloading checkpoints on the fly
|
||||
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
|
||||
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
|
||||
|
||||
## Installation and Running
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||
|
@ -83,6 +97,9 @@ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusio
|
|||
|
||||
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
|
||||
|
||||
## Contributing
|
||||
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
|
||||
|
||||
## Documentation
|
||||
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
|
|
@ -359,7 +359,6 @@ Antanas Sutkus,0.7369492,black-white
|
|||
Leonora Carrington,0.73726475,scribbles
|
||||
Hieronymus Bosch,0.7369955,scribbles
|
||||
A. J. Casson,0.73666203,scribbles
|
||||
A.J.Casson,0.73666203,scribbles
|
||||
Chaim Soutine,0.73662066,scribbles
|
||||
Artur Bordalo,0.7364549,weird
|
||||
Thomas Allom,0.68792284,fineart
|
||||
|
@ -1907,7 +1906,6 @@ Alex Schomburg,0.46614102,digipa-low-impact
|
|||
Bastien L. Deharme,0.583349,special
|
||||
František Jakub Prokyš,0.58782333,fineart
|
||||
Jesper Ejsing,0.58782053,fineart
|
||||
Jesper Ejsing,0.58782053,fineart
|
||||
Odd Nerdrum,0.53551745,digipa-high-impact
|
||||
Tom Lovell,0.5877577,fineart
|
||||
Ayami Kojima,0.5877416,fineart
|
||||
|
|
|
|
@ -58,8 +58,8 @@ titles = {
|
|||
|
||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
|
||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||
|
||||
"Loopback": "Process an image, use it as an input, repeat.",
|
||||
|
|
|
@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
|
|||
onUiUpdate(function(){
|
||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
||||
})
|
||||
|
||||
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
||||
|
|
8
javascript/textualInversion.js
Normal file
8
javascript/textualInversion.js
Normal file
|
@ -0,0 +1,8 @@
|
|||
|
||||
|
||||
function start_training_textual_inversion(){
|
||||
requestProgress('ti')
|
||||
gradioApp().querySelector('#ti_error').innerHTML=''
|
||||
|
||||
return args_to_array(arguments)
|
||||
}
|
|
@ -186,10 +186,12 @@ onUiUpdate(function(){
|
|||
if (!txt2img_textarea) {
|
||||
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
|
||||
}
|
||||
if (!img2img_textarea) {
|
||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -197,6 +199,14 @@ let txt2img_textarea, img2img_textarea = undefined;
|
|||
let wait_time = 800
|
||||
let token_timeout;
|
||||
|
||||
function submit_prompt(event, generate_button_id) {
|
||||
if (event.altKey && event.keyCode === 13) {
|
||||
event.preventDefault();
|
||||
gradioApp().getElementById(generate_button_id).click();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
function update_token_counter(button_id) {
|
||||
if (token_timeout)
|
||||
clearTimeout(token_timeout);
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# this scripts installs necessary requirements and launches main program in webui.py
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
|
@ -19,10 +18,9 @@ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/Tencen
|
|||
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
|
||||
|
||||
args = shlex.split(commandline_args)
|
||||
|
||||
|
@ -120,8 +118,6 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
|
|||
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
|
||||
git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
|
||||
|
||||
if not is_installed("lpips"):
|
||||
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
||||
|
@ -130,6 +126,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
|||
|
||||
sys.argv += args
|
||||
|
||||
if "--exit" in args:
|
||||
print("Exiting because of --exit argument")
|
||||
exit(0)
|
||||
|
||||
def start_webui():
|
||||
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
|
||||
|
|
78
modules/bsrgan_model.py
Normal file
78
modules/bsrgan_model.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
import os.path
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import shared, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "BSRGAN"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.model_name = "BSRGAN 4x"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
except Exception:
|
||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img: PIL.Image, selected_file):
|
||||
torch.cuda.empty_cache()
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(shared.device)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
torch.cuda.empty_cache()
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
||||
return None
|
||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
return model
|
||||
|
102
modules/bsrgan_model_arch.py
Normal file
102
modules/bsrgan_model_arch.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
|
@ -5,31 +5,31 @@ import traceback
|
|||
import cv2
|
||||
import torch
|
||||
|
||||
from modules import shared, devices
|
||||
from modules.paths import script_path
|
||||
import modules.shared
|
||||
import modules.face_restoration
|
||||
from importlib import reload
|
||||
import modules.shared
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
|
||||
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
# I am making a choice to include some files from codeformer to work around this issue.
|
||||
|
||||
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dir = "Codeformer"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
|
||||
have_codeformer = False
|
||||
codeformer = None
|
||||
|
||||
def setup_codeformer():
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
path = modules.paths.paths.get("CodeFormer", None)
|
||||
if path is None:
|
||||
return
|
||||
|
||||
|
||||
# both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
|
||||
#stored_sys_path = sys.path
|
||||
#sys.path = [path] + sys.path
|
||||
|
||||
try:
|
||||
from torchvision.transforms.functional import normalize
|
||||
from modules.codeformer.codeformer_arch import CodeFormer
|
||||
|
@ -44,18 +44,23 @@ def setup_codeformer():
|
|||
def name(self):
|
||||
return "CodeFormer"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, dirname):
|
||||
self.net = None
|
||||
self.face_helper = None
|
||||
self.cmd_dir = dirname
|
||||
|
||||
def create_models(self):
|
||||
|
||||
if self.net is not None and self.face_helper is not None:
|
||||
self.net.to(devices.device_codeformer)
|
||||
return self.net, self.face_helper
|
||||
|
||||
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
|
||||
if len(model_paths) != 0:
|
||||
ckpt_path = model_paths[0]
|
||||
else:
|
||||
print("Unable to load codeformer model.")
|
||||
return None, None
|
||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
||||
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
|
||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
|
@ -74,6 +79,9 @@ def setup_codeformer():
|
|||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
self.create_models()
|
||||
if self.net is None or self.face_helper is None:
|
||||
return np_image
|
||||
|
||||
self.face_helper.clean_all()
|
||||
self.face_helper.read_image(np_image)
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
|
@ -114,7 +122,7 @@ def setup_codeformer():
|
|||
have_codeformer = True
|
||||
|
||||
global codeformer
|
||||
codeformer = FaceRestorerCodeFormer()
|
||||
codeformer = FaceRestorerCodeFormer(dirname)
|
||||
shared.face_restorers.append(codeformer)
|
||||
|
||||
except Exception:
|
||||
|
|
|
@ -32,10 +32,9 @@ def enable_tf32():
|
|||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
|
||||
device = get_optimal_device()
|
||||
device_codeformer = cpu if has_mps else device
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
def randn(seed, shape):
|
||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
|
|
|
@ -1,26 +1,22 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.esrgam_model_arch as arch
|
||||
from modules import shared
|
||||
from modules.shared import opts
|
||||
from modules import shared, modelloader, images
|
||||
from modules.devices import has_mps
|
||||
import modules.images
|
||||
from modules.paths import models_path
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def load_model(filename):
|
||||
def fix_model_layers(crt_model, pretrained_net):
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
if 'conv_first.weight' in pretrained_net:
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
return crt_model
|
||||
return pretrained_net
|
||||
|
||||
if 'model.0.weight' not in pretrained_net:
|
||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
||||
|
@ -72,9 +68,59 @@ def load_model(filename):
|
|||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
||||
|
||||
crt_model.load_state_dict(crt_net)
|
||||
crt_model.eval()
|
||||
return crt_model
|
||||
return crt_net
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "ESRGAN"
|
||||
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||
self.model_name = "ESRGAN 4x"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
|
||||
scaler_data = UpscalerData(name, file, self, 4)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
model = self.load_model(selected_model)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="%s.pth" % self.model_name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||
return None
|
||||
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
crt_model.eval()
|
||||
|
||||
return crt_model
|
||||
|
||||
|
||||
def upscale_without_tiling(model, img):
|
||||
img = np.array(img)
|
||||
|
@ -95,7 +141,7 @@ def esrgan_upscale(model, img):
|
|||
if opts.ESRGAN_tile == 0:
|
||||
return upscale_without_tiling(model, img)
|
||||
|
||||
grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||
newtiles = []
|
||||
scale_factor = 1
|
||||
|
||||
|
@ -110,32 +156,6 @@ def esrgan_upscale(model, img):
|
|||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||
|
||||
newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = modules.images.combine_grid(newgrid)
|
||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = images.combine_grid(newgrid)
|
||||
return output
|
||||
|
||||
|
||||
class UpscalerESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.model = load_model(filename)
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = self.model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
|
||||
def load_models(dirname):
|
||||
for file in os.listdir(dirname):
|
||||
path = os.path.join(dirname, file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
|
||||
if extension != '.pt' and extension != '.pth':
|
||||
continue
|
||||
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
|
||||
except Exception:
|
||||
print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
|
|
@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||
|
||||
outputs = []
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
@ -74,7 +76,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.upscale(image, image.width * resize, image.height * resize)
|
||||
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
cached_images[key] = c
|
||||
|
||||
return c
|
||||
|
@ -189,9 +191,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
|||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
|
||||
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
||||
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
print(f"Saving to {output_modelname}...")
|
||||
torch.save(primary_model, output_modelname)
|
||||
|
|
|
@ -1,39 +1,25 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from glob import glob
|
||||
|
||||
from modules import shared, devices
|
||||
from modules.shared import cmd_opts
|
||||
from modules.paths import script_path
|
||||
import facexlib
|
||||
import gfpgan
|
||||
|
||||
import modules.face_restoration
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
def gfpgan_model_path():
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
filemask = 'GFPGAN*.pth'
|
||||
|
||||
if cmd_opts.gfpgan_model is not None:
|
||||
return cmd_opts.gfpgan_model
|
||||
|
||||
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
|
||||
|
||||
filename = None
|
||||
for place in places:
|
||||
filename = next(iter(glob(os.path.join(place, filemask))), None)
|
||||
if filename is not None:
|
||||
break
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
user_path = None
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
|
||||
|
||||
def gfpgan():
|
||||
def gfpgann():
|
||||
global loaded_gfpgan_model
|
||||
|
||||
global model_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
loaded_gfpgan_model.gfpgan.to(shared.device)
|
||||
return loaded_gfpgan_model
|
||||
|
@ -41,7 +27,16 @@ def gfpgan():
|
|||
if gfpgan_constructor is None:
|
||||
return None
|
||||
|
||||
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
||||
if len(models) == 1 and "http" in models[0]:
|
||||
model_file = models[0]
|
||||
elif len(models) != 0:
|
||||
latest_file = max(models, key=os.path.getctime)
|
||||
model_file = latest_file
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
model.gfpgan.to(shared.device)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
|
@ -49,8 +44,9 @@ def gfpgan():
|
|||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
model = gfpgan()
|
||||
|
||||
model = gfpgann()
|
||||
if model is None:
|
||||
return np_image
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
@ -61,21 +57,39 @@ def gfpgan_fix_faces(np_image):
|
|||
return np_image
|
||||
|
||||
|
||||
have_gfpgan = False
|
||||
gfpgan_constructor = None
|
||||
|
||||
def setup_gfpgan():
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
try:
|
||||
gfpgan_model_path()
|
||||
|
||||
if os.path.exists(cmd_opts.gfpgan_dir):
|
||||
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
from facexlib import detection, parsing
|
||||
global user_path
|
||||
global have_gfpgan
|
||||
have_gfpgan = True
|
||||
|
||||
global gfpgan_constructor
|
||||
|
||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||
|
||||
def my_load_file_from_url(**kwargs):
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||
|
||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
user_path = dirname
|
||||
have_gfpgan = True
|
||||
gfpgan_constructor = GFPGANer
|
||||
|
||||
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
||||
|
@ -84,7 +98,7 @@ def setup_gfpgan():
|
|||
|
||||
def restore(self, np_image):
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
return np_image
|
||||
|
|
|
@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
|||
from fonts.ttf import Roboto
|
||||
import string
|
||||
|
||||
import modules.shared
|
||||
from modules import sd_samplers, shared
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
|
@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
|||
cols = math.ceil((w - overlap) / non_overlap_width)
|
||||
rows = math.ceil((h - overlap) / non_overlap_height)
|
||||
|
||||
dx = (w - tile_w) / (cols-1) if cols > 1 else 0
|
||||
dy = (h - tile_h) / (rows-1) if rows > 1 else 0
|
||||
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
||||
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
||||
|
||||
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
||||
for row in range(rows):
|
||||
|
@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
|||
for col in range(cols):
|
||||
x = int(col * dx)
|
||||
|
||||
if x+tile_w >= w:
|
||||
if x + tile_w >= w:
|
||||
x = w - tile_w
|
||||
|
||||
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
||||
|
@ -132,7 +131,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
|||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
|
||||
if not line.is_active:
|
||||
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
|
||||
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
||||
|
||||
draw_y += line.size[1] + line_spacing
|
||||
|
||||
|
@ -171,7 +170,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
|||
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
||||
|
||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||
ver_texts]
|
||||
|
||||
pad_top = max(hor_text_heights) + line_spacing * 2
|
||||
|
||||
|
@ -213,8 +213,19 @@ def resize_image(resize_mode, im, width, height):
|
|||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
|
||||
return im.resize((w, h), resample=LANCZOS)
|
||||
|
||||
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
|
||||
return upscaler.upscale(im, w, h)
|
||||
scale = max(w / im.width, h / im.height)
|
||||
|
||||
if scale > 1.0:
|
||||
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
|
||||
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
|
||||
|
||||
upscaler = upscalers[0]
|
||||
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
||||
|
||||
if im.width != w or im.height != h:
|
||||
im = im.resize((w, h), resample=LANCZOS)
|
||||
|
||||
return im
|
||||
|
||||
if resize_mode == 0:
|
||||
res = resize(im, width, height)
|
||||
|
@ -256,7 +267,7 @@ def resize_image(resize_mode, im, width, height):
|
|||
invalid_filename_chars = '<>:"/\\|?*\n'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s'+string.punctuation+']+')
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
max_filename_part_length = 128
|
||||
|
||||
|
||||
|
@ -278,6 +289,16 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||
|
||||
if prompt is not None:
|
||||
x = x.replace("[prompt]", sanitize_filename_part(prompt))
|
||||
if "[prompt_no_styles]" in x:
|
||||
prompt_no_style = prompt
|
||||
for style in shared.prompt_styles.get_style_prompts(p.styles):
|
||||
if len(style) > 0:
|
||||
style_parts = [y for y in style.split("{prompt}")]
|
||||
for part in style_parts:
|
||||
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
||||
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
||||
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
|
||||
|
||||
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
|
||||
if "[prompt_words]" in x:
|
||||
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
||||
|
@ -290,10 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||
x = x.replace("[cfg]", str(p.cfg_scale))
|
||||
x = x.replace("[width]", str(p.width))
|
||||
x = x.replace("[height]", str(p.height))
|
||||
|
||||
#currently disabled if using the save button, will work otherwise
|
||||
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
|
||||
if hasattr(p, "styles"):
|
||||
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
|
||||
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
|
||||
|
||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
|
||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||
|
@ -306,6 +329,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||
|
||||
return x
|
||||
|
||||
|
||||
def get_next_sequence_number(path, basename):
|
||||
"""
|
||||
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
||||
|
@ -319,7 +343,7 @@ def get_next_sequence_number(path, basename):
|
|||
prefix_length = len(basename)
|
||||
for p in os.listdir(path):
|
||||
if p.startswith(basename):
|
||||
l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
try:
|
||||
result = max(int(l[0]), result)
|
||||
except ValueError:
|
||||
|
@ -327,6 +351,7 @@ def get_next_sequence_number(path, basename):
|
|||
|
||||
return result + 1
|
||||
|
||||
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
|
||||
if short_filename or prompt is None or seed is None:
|
||||
file_decoration = ""
|
||||
|
@ -364,7 +389,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
fullfn = "a.png"
|
||||
fullfn_without_extension = "a"
|
||||
for i in range(500):
|
||||
fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
|
||||
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
||||
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
||||
if not os.path.exists(fullfn):
|
||||
|
@ -406,31 +431,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
file.write(info + "\n")
|
||||
|
||||
|
||||
class Upscaler:
|
||||
name = "Lanczos"
|
||||
|
||||
def do_upscale(self, img):
|
||||
return img
|
||||
|
||||
def upscale(self, img, w, h):
|
||||
for i in range(3):
|
||||
if img.width >= w and img.height >= h:
|
||||
break
|
||||
|
||||
img = self.do_upscale(img)
|
||||
|
||||
if img.width != w or img.height != h:
|
||||
img = img.resize((int(w), int(h)), resample=LANCZOS)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class UpscalerNone(Upscaler):
|
||||
name = "None"
|
||||
|
||||
def upscale(self, img, w, h):
|
||||
return img
|
||||
|
||||
|
||||
modules.shared.sd_upscalers.append(UpscalerNone())
|
||||
modules.shared.sd_upscalers.append(Upscaler())
|
||||
|
|
|
@ -1,67 +1,56 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.images
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.ldsr_model_arch import LDSR
|
||||
from modules import shared
|
||||
from modules.paths import script_path
|
||||
|
||||
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
|
||||
|
||||
ldsr_models = []
|
||||
have_ldsr = False
|
||||
LDSR_obj = None
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class UpscalerLDSR(modules.images.Upscaler):
|
||||
def __init__(self, steps):
|
||||
self.steps = steps
|
||||
class UpscalerLDSR(Upscaler):
|
||||
def __init__(self, user_path):
|
||||
self.name = "LDSR"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.user_path = user_path
|
||||
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
super().__init__()
|
||||
scaler_data = UpscalerData("LDSR", None, self)
|
||||
self.scalers = [scaler_data]
|
||||
|
||||
def do_upscale(self, img):
|
||||
return upscale_with_ldsr(img)
|
||||
def load_model(self, path: str):
|
||||
# Remove incorrect project.yaml file if too big
|
||||
yaml_path = os.path.join(self.model_path, "project.yaml")
|
||||
old_model_path = os.path.join(self.model_path, "model.pth")
|
||||
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
||||
if os.path.exists(yaml_path):
|
||||
statinfo = os.stat(yaml_path)
|
||||
if statinfo.st_size >= 10485760:
|
||||
print("Removing invalid LDSR YAML file.")
|
||||
os.remove(yaml_path)
|
||||
if os.path.exists(old_model_path):
|
||||
print("Renaming model from model.pth to model.ckpt")
|
||||
os.rename(old_model_path, new_model_path)
|
||||
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="model.ckpt", progress=True)
|
||||
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
||||
file_name="project.yaml", progress=True)
|
||||
|
||||
try:
|
||||
return LDSR(model, yaml)
|
||||
|
||||
def add_lsdr():
|
||||
modules.shared.sd_upscalers.append(UpscalerLDSR(100))
|
||||
except Exception:
|
||||
print("Error importing LDSR:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def setup_ldsr():
|
||||
path = modules.paths.paths.get("LDSR", None)
|
||||
if path is None:
|
||||
return
|
||||
global have_ldsr
|
||||
global LDSR_obj
|
||||
try:
|
||||
from LDSR import LDSR
|
||||
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
repo_path = 'latent-diffusion/experiments/pretrained_models/'
|
||||
model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path),
|
||||
progress=True, file_name="model.chkpt")
|
||||
yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path),
|
||||
progress=True, file_name="project.yaml")
|
||||
have_ldsr = True
|
||||
LDSR_obj = LDSR(model_path, yaml_path)
|
||||
|
||||
|
||||
except Exception:
|
||||
print("Error importing LDSR:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
have_ldsr = False
|
||||
|
||||
|
||||
def upscale_with_ldsr(image):
|
||||
setup_ldsr()
|
||||
if not have_ldsr or LDSR_obj is None:
|
||||
return image
|
||||
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
pre_scale = shared.opts.ldsr_pre_down
|
||||
post_scale = shared.opts.ldsr_post_down
|
||||
|
||||
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
|
||||
return image
|
||||
def do_upscale(self, img, path):
|
||||
ldsr = self.load_model(path)
|
||||
if ldsr is None:
|
||||
print("NO LDSR!")
|
||||
return img
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||
|
|
222
modules/ldsr_model_arch.py
Normal file
222
modules/ldsr_model_arch.py
Normal file
|
@ -0,0 +1,222 @@
|
|||
import gc
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import instantiate_from_config, ismap
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
|
||||
# Create LDSR Class
|
||||
class LDSR:
|
||||
def load_model_from_config(self, half_attention):
|
||||
print(f"Loading model from {self.modelPath}")
|
||||
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||
sd = pl_sd["state_dict"]
|
||||
config = OmegaConf.load(self.yamlPath)
|
||||
model = instantiate_from_config(config.model)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
model.cuda()
|
||||
if half_attention:
|
||||
model = model.half()
|
||||
|
||||
model.eval()
|
||||
return {"model": model}
|
||||
|
||||
def __init__(self, model_path, yaml_path):
|
||||
self.modelPath = model_path
|
||||
self.yamlPath = yaml_path
|
||||
|
||||
@staticmethod
|
||||
def run(model, selected_path, custom_steps, eta):
|
||||
example = get_cond(selected_path)
|
||||
|
||||
n_runs = 1
|
||||
guider = None
|
||||
ckwargs = None
|
||||
ddim_use_x0_pred = False
|
||||
temperature = 1.
|
||||
eta = eta
|
||||
custom_shape = None
|
||||
|
||||
height, width = example["image"].shape[1:3]
|
||||
split_input = height >= 128 and width >= 128
|
||||
|
||||
if split_input:
|
||||
ks = 128
|
||||
stride = 64
|
||||
vqf = 4 #
|
||||
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
||||
"vqf": vqf,
|
||||
"patch_distributed_vq": True,
|
||||
"tie_braker": False,
|
||||
"clip_max_weight": 0.5,
|
||||
"clip_min_weight": 0.01,
|
||||
"clip_max_tie_weight": 0.5,
|
||||
"clip_min_tie_weight": 0.01}
|
||||
else:
|
||||
if hasattr(model, "split_input_params"):
|
||||
delattr(model, "split_input_params")
|
||||
|
||||
x_t = None
|
||||
logs = None
|
||||
for n in range(n_runs):
|
||||
if custom_shape is not None:
|
||||
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||
|
||||
logs = make_convolutional_sample(example, model,
|
||||
custom_steps=custom_steps,
|
||||
eta=eta, quantize_x0=False,
|
||||
custom_shape=custom_shape,
|
||||
temperature=temperature, noise_dropout=0.,
|
||||
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
||||
ddim_use_x0_pred=ddim_use_x0_pred
|
||||
)
|
||||
return logs
|
||||
|
||||
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
||||
model = self.load_model_from_config(half_attention)
|
||||
|
||||
# Run settings
|
||||
diffusion_steps = int(steps)
|
||||
eta = 1.0
|
||||
|
||||
down_sample_method = 'Lanczos'
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
im_og = image
|
||||
width_og, height_og = im_og.size
|
||||
# If we can adjust the max upscale size, then the 4 below should be our variable
|
||||
down_sample_rate = target_scale / 4
|
||||
wd = width_og * down_sample_rate
|
||||
hd = height_og * down_sample_rate
|
||||
width_downsampled_pre = int(wd)
|
||||
height_downsampled_pre = int(hd)
|
||||
|
||||
if down_sample_rate != 1:
|
||||
print(
|
||||
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||
else:
|
||||
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||
|
||||
sample = logs["sample"]
|
||||
sample = sample.detach().cpu()
|
||||
sample = torch.clamp(sample, -1., 1.)
|
||||
sample = (sample + 1.) / 2. * 255
|
||||
sample = sample.numpy().astype(np.uint8)
|
||||
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||
a = Image.fromarray(sample[0])
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return a
|
||||
|
||||
|
||||
def get_cond(selected_path):
|
||||
example = dict()
|
||||
up_f = 4
|
||||
c = selected_path.convert('RGB')
|
||||
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
||||
antialias=True)
|
||||
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
||||
c = rearrange(c, '1 c h w -> 1 h w c')
|
||||
c = 2. * c - 1.
|
||||
|
||||
c = c.to(torch.device("cuda"))
|
||||
example["LR_image"] = c
|
||||
example["image"] = c_up
|
||||
|
||||
return example
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
||||
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
||||
corrector_kwargs=None, x_t=None
|
||||
):
|
||||
ddim = DDIMSampler(model)
|
||||
bs = shape[0]
|
||||
shape = shape[1:]
|
||||
print(f"Sampling with eta = {eta}; steps: {steps}")
|
||||
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
||||
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
||||
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
||||
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||
log = dict()
|
||||
|
||||
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=not (hasattr(model, 'split_input_params')
|
||||
and model.cond_stage_key == 'coordinates_bbox'),
|
||||
return_original_cond=True)
|
||||
|
||||
if custom_shape is not None:
|
||||
z = torch.randn(custom_shape)
|
||||
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
||||
|
||||
z0 = None
|
||||
|
||||
log["input"] = x
|
||||
log["reconstruction"] = xrec
|
||||
|
||||
if ismap(xc):
|
||||
log["original_conditioning"] = model.to_rgb(xc)
|
||||
if hasattr(model, 'cond_stage_key'):
|
||||
log[model.cond_stage_key] = model.to_rgb(xc)
|
||||
|
||||
else:
|
||||
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
||||
if model.cond_stage_model:
|
||||
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
||||
if model.cond_stage_key == 'class_label':
|
||||
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
||||
|
||||
with model.ema_scope("Plotting"):
|
||||
t0 = time.time()
|
||||
|
||||
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
||||
eta=eta,
|
||||
quantize_x0=quantize_x0, mask=None, x0=z0,
|
||||
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
||||
x_t=x_T)
|
||||
t1 = time.time()
|
||||
|
||||
if ddim_use_x0_pred:
|
||||
sample = intermediates['pred_x0'][-1]
|
||||
|
||||
x_sample = model.decode_first_stage(sample)
|
||||
|
||||
try:
|
||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||
log["sample_noquant"] = x_sample_noquant
|
||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||
except:
|
||||
pass
|
||||
|
||||
log["sample"] = x_sample
|
||||
log["time"] = t1 - t0
|
||||
|
||||
return log
|
140
modules/modelloader.py
Normal file
140
modules/modelloader.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
|
||||
"""
|
||||
A one-and done loader to try finding the desired models in specified directories.
|
||||
|
||||
@param download_name: Specify to download from model_url immediately.
|
||||
@param model_url: If no other models are found, this will be downloaded on upscale.
|
||||
@param model_path: The location to store/find models in.
|
||||
@param command_path: A command-line argument to search for models in first.
|
||||
@param ext_filter: An optional list of filename extensions to filter by
|
||||
@return: A list of paths containing the desired model(s)
|
||||
"""
|
||||
output = []
|
||||
|
||||
if ext_filter is None:
|
||||
ext_filter = []
|
||||
|
||||
try:
|
||||
places = []
|
||||
|
||||
if command_path is not None and command_path != model_path:
|
||||
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
||||
if os.path.exists(pretrained_path):
|
||||
print(f"Appending path: {pretrained_path}")
|
||||
places.append(pretrained_path)
|
||||
elif os.path.exists(command_path):
|
||||
places.append(command_path)
|
||||
|
||||
places.append(model_path)
|
||||
|
||||
for place in places:
|
||||
if os.path.exists(place):
|
||||
for file in glob.iglob(place + '**/**', recursive=True):
|
||||
full_path = file
|
||||
if os.path.isdir(full_path):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
model_name, extension = os.path.splitext(file)
|
||||
if extension not in ext_filter:
|
||||
continue
|
||||
if file not in output:
|
||||
output.append(full_path)
|
||||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||
output.append(dl)
|
||||
else:
|
||||
output.append(model_url)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def friendly_name(file: str):
|
||||
if "http" in file:
|
||||
file = urlparse(file).path
|
||||
|
||||
file = os.path.basename(file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
return model_name
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
||||
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
||||
# somehow auto-register and just do these things...
|
||||
root_path = script_path
|
||||
src_path = models_path
|
||||
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||
move_files(src_path, dest_path, ".ckpt")
|
||||
src_path = os.path.join(root_path, "ESRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "gfpgan")
|
||||
dest_path = os.path.join(models_path, "GFPGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "SwinIR")
|
||||
dest_path = os.path.join(models_path, "SwinIR")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
||||
dest_path = os.path.join(models_path, "LDSR")
|
||||
move_files(src_path, dest_path)
|
||||
|
||||
|
||||
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
try:
|
||||
if not os.path.exists(dest_path):
|
||||
os.makedirs(dest_path)
|
||||
if os.path.exists(src_path):
|
||||
for file in os.listdir(src_path):
|
||||
fullpath = os.path.join(src_path, file)
|
||||
if os.path.isfile(fullpath):
|
||||
if ext_filter is not None:
|
||||
if ext_filter not in file:
|
||||
continue
|
||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||
try:
|
||||
shutil.move(fullpath, dest_path)
|
||||
except:
|
||||
pass
|
||||
if len(os.listdir(src_path)) == 0:
|
||||
print(f"Removing empty folder: {src_path}")
|
||||
shutil.rmtree(src_path, True)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
datas = []
|
||||
for cls in Upscaler.__subclasses__():
|
||||
name = cls.__name__
|
||||
module_name = cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
class_ = getattr(module, name)
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
|
||||
opt_string = None
|
||||
try:
|
||||
opt_string = shared.opts.__getattr__(cmd_name)
|
||||
except:
|
||||
pass
|
||||
scaler = class_(opt_string)
|
||||
for child in scaler.scalers:
|
||||
datas.append(child)
|
||||
|
||||
shared.sd_upscalers = datas
|
|
@ -3,9 +3,10 @@ import os
|
|||
import sys
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
models_path = os.path.join(script_path, "models")
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffsuion in following palces
|
||||
# search for directory of stable diffusion in following places
|
||||
sd_path = None
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
||||
for possible_sd_path in possible_sd_paths:
|
||||
|
@ -15,21 +16,24 @@ for possible_sd_path in possible_sd_paths:
|
|||
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
|
||||
|
||||
path_dirs = [
|
||||
(sd_path, 'ldm', 'Stable Diffusion'),
|
||||
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
|
||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
|
||||
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
|
||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
|
||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
|
||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
]
|
||||
|
||||
paths = {}
|
||||
|
||||
for d, must_exist, what in path_dirs:
|
||||
for d, must_exist, what, options in path_dirs:
|
||||
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
|
||||
if not os.path.exists(must_exist_path):
|
||||
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
|
||||
else:
|
||||
d = os.path.abspath(d)
|
||||
sys.path.append(d)
|
||||
if "atstart" in options:
|
||||
sys.path.insert(0, d)
|
||||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
|
|
@ -56,7 +56,7 @@ class StableDiffusionProcessing:
|
|||
self.prompt: str = prompt
|
||||
self.prompt_for_display: str = None
|
||||
self.negative_prompt: str = (negative_prompt or "")
|
||||
self.styles: str = styles
|
||||
self.styles: list = styles or []
|
||||
self.seed: int = seed
|
||||
self.subseed: int = subseed
|
||||
self.subseed_strength: float = subseed_strength
|
||||
|
@ -79,7 +79,7 @@ class StableDiffusionProcessing:
|
|||
self.paste_to = None
|
||||
self.color_corrections = None
|
||||
self.denoising_strength: float = 0
|
||||
|
||||
self.sampler_noise_scheduler_override = None
|
||||
self.ddim_discretize = opts.ddim_discretize
|
||||
self.s_churn = opts.s_churn
|
||||
self.s_tmin = opts.s_tmin
|
||||
|
@ -130,7 +130,7 @@ class Processed:
|
|||
self.s_tmin = p.s_tmin
|
||||
self.s_tmax = p.s_tmax
|
||||
self.s_noise = p.s_noise
|
||||
|
||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
||||
|
@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
}
|
||||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
fix_seed(p)
|
||||
|
||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||
if p.outpath_samples is not None:
|
||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||
|
||||
if p.outpath_grids is not None:
|
||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||
|
||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||
|
||||
|
@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir):
|
||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
@ -492,8 +495,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
image = Image.fromarray(x_sample)
|
||||
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
|
||||
image = upscaler.upscale(image, self.width, self.height)
|
||||
image = images.resize_image(0, image, self.width, self.height)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
batch_images.append(image)
|
||||
|
|
|
@ -1,119 +1,135 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
import modules.images
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.paths import models_path
|
||||
from modules.shared import cmd_opts, opts
|
||||
|
||||
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
|
||||
realesrgan_models = []
|
||||
have_realesrgan = False
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
def __init__(self, path):
|
||||
self.name = "RealESRGAN"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.user_path = path
|
||||
super().__init__()
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
self.enable = True
|
||||
self.scalers = []
|
||||
scalers = self.load_models(path)
|
||||
for scaler in scalers:
|
||||
if scaler.name in opts.realesrgan_enabled_models:
|
||||
self.scalers.append(scaler)
|
||||
|
||||
except Exception:
|
||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.enable = False
|
||||
self.scalers = []
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
if not self.enable:
|
||||
return img
|
||||
|
||||
info = self.load_model(path)
|
||||
if not os.path.exists(info.data_path):
|
||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.scale,
|
||||
model_path=info.data_path,
|
||||
model=info.model(),
|
||||
half=not cmd_opts.no_half,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
return image
|
||||
|
||||
def load_model(self, path):
|
||||
try:
|
||||
info = None
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
info = scaler
|
||||
|
||||
if info is None:
|
||||
print(f"Unable to find model info: {path}")
|
||||
return None
|
||||
|
||||
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
info.data_path = model_file
|
||||
return info
|
||||
except Exception as e:
|
||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
def load_models(self, _):
|
||||
return get_realesrgan_models(self)
|
||||
|
||||
|
||||
def get_realesrgan_models():
|
||||
def get_realesrgan_models(scaler):
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
models = [
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN General x4x3",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
netscale=4,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN General WDN x4x3",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
netscale=4,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General WDN 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN AnimeVideo",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
netscale=4,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN AnimeVideo",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 4x plus",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
netscale=4,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 4x plus anime 6B",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
netscale=4,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+ Anime6B",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 2x plus",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
netscale=2,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 2x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
),
|
||||
]
|
||||
return models
|
||||
except Exception as e:
|
||||
print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
|
||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, upscaling, model_index):
|
||||
self.upscaling = upscaling
|
||||
self.model_index = model_index
|
||||
self.name = realesrgan_models[model_index].name
|
||||
|
||||
def do_upscale(self, img):
|
||||
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
|
||||
|
||||
|
||||
def setup_realesrgan():
|
||||
global realesrgan_models
|
||||
global have_realesrgan
|
||||
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
realesrgan_models = get_realesrgan_models()
|
||||
have_realesrgan = True
|
||||
|
||||
for i, model in enumerate(realesrgan_models):
|
||||
if model.name in opts.realesrgan_enabled_models:
|
||||
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
|
||||
|
||||
except Exception:
|
||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
|
||||
have_realesrgan = False
|
||||
|
||||
|
||||
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
|
||||
if not have_realesrgan:
|
||||
return image
|
||||
|
||||
info = realesrgan_models[RealESRGAN_model_index]
|
||||
|
||||
model = info.model()
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.netscale,
|
||||
model_path=info.location,
|
||||
model=model,
|
||||
half=not cmd_opts.no_half,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
return image
|
||||
|
|
|
@ -6,253 +6,51 @@ import torch
|
|||
import numpy as np
|
||||
from torch import einsum
|
||||
|
||||
from modules import prompt_parser
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||
from modules.shared import opts, device, cmd_opts
|
||||
|
||||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
|
||||
|
||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
|
||||
# taken from https://github.com/Doggettx/stable-diffusion
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
def apply_optimizations():
|
||||
if cmd_opts.opt_split_attention_v1:
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context) * self.scale
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
def undo_optimizations():
|
||||
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
def nonlinearity_hijack(x):
|
||||
# swish
|
||||
t = torch.sigmoid(x)
|
||||
x *= t
|
||||
del t
|
||||
|
||||
return x
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q1 = self.q(h_)
|
||||
k1 = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q1.shape
|
||||
|
||||
q2 = q1.reshape(b, c, h*w)
|
||||
del q1
|
||||
|
||||
q = q2.permute(0, 2, 1) # b,hw,c
|
||||
del q2
|
||||
|
||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||
del k1
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
|
||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w2 = w1 * (int(c)**(-0.5))
|
||||
del w1
|
||||
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
||||
del w2
|
||||
|
||||
# attend to values
|
||||
v1 = v.reshape(b, c, h*w)
|
||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
del w3
|
||||
|
||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
del v1, w4
|
||||
|
||||
h2 = h_.reshape(b, c, h, w)
|
||||
del h_
|
||||
|
||||
h3 = self.proj_out(h2)
|
||||
del h2
|
||||
|
||||
h3 += x
|
||||
|
||||
return h3
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
ids_lookup = {}
|
||||
word_embeddings = {}
|
||||
word_embeddings_checksums = {}
|
||||
fixes = None
|
||||
comments = []
|
||||
dir_mtime = None
|
||||
layers = None
|
||||
circular_enabled = False
|
||||
clip = None
|
||||
|
||||
def load_textual_inversion_embeddings(self, dirname, model):
|
||||
mt = os.path.getmtime(dirname)
|
||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||
return
|
||||
|
||||
self.dir_mtime = mt
|
||||
self.ids_lookup.clear()
|
||||
self.word_embeddings.clear()
|
||||
|
||||
tokenizer = model.cond_stage_model.tokenizer
|
||||
|
||||
def const_hash(a):
|
||||
r = 0
|
||||
for v in a:
|
||||
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
||||
return r
|
||||
|
||||
def process_file(path, filename):
|
||||
name = os.path.splitext(filename)[0]
|
||||
|
||||
data = torch.load(path, map_location="cpu")
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
|
||||
self.word_embeddings[name] = emb.detach().to(device)
|
||||
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
|
||||
|
||||
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
||||
|
||||
first_id = ids[0]
|
||||
if first_id not in self.ids_lookup:
|
||||
self.ids_lookup[first_id] = []
|
||||
self.ids_lookup[first_id].append((ids, name))
|
||||
|
||||
for fn in os.listdir(dirname):
|
||||
try:
|
||||
process_file(os.path.join(dirname, fn), fn)
|
||||
except Exception:
|
||||
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
continue
|
||||
|
||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
if cmd_opts.opt_split_attention_v1:
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
apply_optimizations()
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
|
@ -263,6 +61,14 @@ class StableDiffusionModelHijack:
|
|||
|
||||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
return
|
||||
|
@ -282,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.hijack = hijack
|
||||
self.hijack: StableDiffusionModelHijack = hijack
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.max_length = wrapped.max_length
|
||||
self.token_mults = {}
|
||||
|
@ -303,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
|
@ -325,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
||||
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if possible_matches is None:
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
else:
|
||||
found = False
|
||||
for ids, word in possible_matches:
|
||||
if tokens[i:i + len(ids)] == ids:
|
||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
||||
fixes.append((len(remade_tokens), word))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
i += len(ids) - 1
|
||||
found = True
|
||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
||||
break
|
||||
|
||||
if not found:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += emb_len
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
|
@ -417,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
||||
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
elif possible_matches is None:
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
found = False
|
||||
for ids, word in possible_matches:
|
||||
if tokens[i:i+len(ids)] == ids:
|
||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
||||
fixes.append((len(remade_tokens), word))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
i += len(ids) - 1
|
||||
found = True
|
||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
||||
break
|
||||
|
||||
if not found:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
|
||||
i += 1
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += emb_len
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
|
@ -450,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||
|
@ -470,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
|
||||
self.hijack.fixes = hijack_fixes
|
||||
self.hijack.comments = hijack_comments
|
||||
|
||||
|
@ -503,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||
|
||||
inputs_embeds = self.wrapped(input_ids)
|
||||
|
||||
if batch_fixes is not None:
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, word in fixes:
|
||||
emb = self.embeddings.word_embeddings[word]
|
||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
|
||||
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
return inputs_embeds
|
||||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = embedding.vec
|
||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
||||
|
||||
vecs.append(tensor)
|
||||
|
||||
return torch.stack(vecs)
|
||||
|
||||
|
||||
def add_circular_option_to_conv_2d():
|
||||
|
|
164
modules/sd_hijack_optimizations.py
Normal file
164
modules/sd_hijack_optimizations.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import einsum
|
||||
|
||||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
|
||||
# taken from https://github.com/Doggettx/stable-diffusion
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context) * self.scale
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
def nonlinearity_hijack(x):
|
||||
# swish
|
||||
t = torch.sigmoid(x)
|
||||
x *= t
|
||||
del t
|
||||
|
||||
return x
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q1 = self.q(h_)
|
||||
k1 = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q1.shape
|
||||
|
||||
q2 = q1.reshape(b, c, h*w)
|
||||
del q1
|
||||
|
||||
q = q2.permute(0, 2, 1) # b,hw,c
|
||||
del q2
|
||||
|
||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||
del k1
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
|
||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w2 = w1 * (int(c)**(-0.5))
|
||||
del w1
|
||||
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
||||
del w2
|
||||
|
||||
# attend to values
|
||||
v1 = v.reshape(b, c, h*w)
|
||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
del w3
|
||||
|
||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
del v1, w4
|
||||
|
||||
h2 = h_.reshape(b, c, h, w)
|
||||
del h_
|
||||
|
||||
h3 = self.proj_out(h2)
|
||||
del h2
|
||||
|
||||
h3 += x
|
||||
|
||||
return h3
|
|
@ -8,7 +8,14 @@ from omegaconf import OmegaConf
|
|||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared
|
||||
from modules import shared, modelloader, devices
|
||||
from modules.paths import models_path
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
model_name = "sd-v1-4.ckpt"
|
||||
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
|
||||
user_dir = None
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||
checkpoints_list = {}
|
||||
|
@ -23,20 +30,30 @@ except Exception:
|
|||
pass
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
global user_dir
|
||||
user_dir = dirname
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
checkpoints_list.clear()
|
||||
list_models()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
return sorted([x.title for x in checkpoints_list.values()])
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
|
||||
|
||||
model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)
|
||||
|
||||
def modeltitle(path, h):
|
||||
def modeltitle(path, shorthash):
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
if abspath.startswith(model_dir):
|
||||
name = abspath.replace(model_dir, '')
|
||||
if user_dir is not None and abspath.startswith(user_dir):
|
||||
name = abspath.replace(user_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
name = os.path.basename(path)
|
||||
|
||||
|
@ -45,21 +62,27 @@ def list_models():
|
|||
|
||||
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
|
||||
return f'{name} [{h}]', shortname
|
||||
return f'{name} [{shorthash}]', shortname
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if os.path.exists(cmd_ckpt):
|
||||
h = model_hash(cmd_ckpt)
|
||||
title, model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
|
||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||
shared.opts.sd_model_checkpoint = title
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
for filename in model_list:
|
||||
h = model_hash(filename)
|
||||
title, short_model_name = modeltitle(filename, h)
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||
|
||||
if os.path.exists(model_dir):
|
||||
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
|
||||
h = model_hash(filename)
|
||||
title, model_name = modeltitle(filename, h)
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
|
||||
|
||||
def get_closet_checkpoint_match(searchString):
|
||||
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
||||
if len(applicable) > 0:
|
||||
return applicable[0]
|
||||
return None
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
|
@ -111,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
|
|||
if not shared.cmd_opts.no_half:
|
||||
model.half()
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpint = checkpoint_file
|
||||
|
||||
|
@ -137,7 +162,7 @@ def load_model():
|
|||
|
||||
|
||||
def reload_model_weights(sd_model, info=None):
|
||||
from modules import lowvram, devices
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
||||
|
@ -148,8 +173,12 @@ def reload_model_weights(sd_model, info=None):
|
|||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import torch
|
|||
import tqdm
|
||||
from PIL import Image
|
||||
import inspect
|
||||
|
||||
import k_diffusion.sampling
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
@ -23,6 +22,8 @@ samplers_k_diffusion = [
|
|||
('Heun', 'sample_heun', ['k_heun']),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2']),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
|
@ -36,7 +37,7 @@ samplers = [
|
|||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
|
||||
]
|
||||
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
||||
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
|
@ -289,7 +290,10 @@ class KDiffusionSampler:
|
|||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
noise = noise * sigmas[steps - t_enc - 1]
|
||||
xi = x + noise
|
||||
|
@ -305,12 +309,20 @@ class KDiffusionSampler:
|
|||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
x = x * sigmas[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
|
||||
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||
return samples
|
||||
|
||||
|
|
|
@ -1,26 +1,27 @@
|
|||
import sys
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
import tqdm
|
||||
import datetime
|
||||
|
||||
import modules.artists
|
||||
from modules.paths import script_path, sd_path
|
||||
from modules.devices import get_optimal_device
|
||||
import modules.styles
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
from modules.devices import get_optimal_device
|
||||
from modules.paths import script_path, sd_path
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
model_path = os.path.join(script_path, 'models')
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
|
@ -34,8 +35,13 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
|
|||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR'))
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
|
@ -53,7 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
|
|||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
device = get_optimal_device()
|
||||
|
||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||
|
@ -61,6 +66,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
|||
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
|
||||
class State:
|
||||
interrupted = False
|
||||
job = ""
|
||||
|
@ -72,6 +78,7 @@ class State:
|
|||
current_latent = None
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
textinfo = None
|
||||
|
||||
def interrupt(self):
|
||||
self.interrupted = True
|
||||
|
@ -82,7 +89,7 @@ class State:
|
|||
self.current_image_sampling_step = 0
|
||||
|
||||
def get_job_timestamp(self):
|
||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||
|
||||
|
||||
state = State()
|
||||
|
@ -95,13 +102,13 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
|||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||
|
||||
face_restorers = []
|
||||
|
||||
modules.sd_models.list_models()
|
||||
# This was moved to webui.py with the other model "setup" calls.
|
||||
# modules.sd_models.list_models()
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models()]
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
|
@ -167,13 +174,10 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||
"ldsr_pre_down": OptionInfo(1, "LDSR Pre-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
|
||||
"ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
|
||||
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
|
@ -190,9 +194,9 @@ options_templates.update(options_section(('system', "System"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
|
|
|
@ -53,6 +53,12 @@ class StyleDatabase:
|
|||
negative_prompt = row.get("negative_prompt", "")
|
||||
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
|
||||
|
||||
def get_style_prompts(self, styles):
|
||||
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
||||
|
||||
def get_negative_style_prompts(self, styles):
|
||||
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
||||
|
||||
def apply_styles_to_prompt(self, prompt, styles):
|
||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
|
||||
|
||||
|
|
|
@ -1,123 +0,0 @@
|
|||
import sys
|
||||
import traceback
|
||||
import cv2
|
||||
import os
|
||||
import contextlib
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import modules.images
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_arch import SwinIR as net
|
||||
|
||||
precision_scope = (
|
||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
)
|
||||
|
||||
|
||||
def load_model(filename, scale=4):
|
||||
model = net(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
window_size=8,
|
||||
img_range=1.0,
|
||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||
embed_dim=240,
|
||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||
mlp_ratio=2,
|
||||
upsampler="nearest+conv",
|
||||
resi_connection="3conv",
|
||||
)
|
||||
|
||||
pretrained_model = torch.load(filename)
|
||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
||||
if not cmd_opts.no_half:
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
|
||||
def load_models(dirname):
|
||||
for file in os.listdir(dirname):
|
||||
path = os.path.join(dirname, file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
|
||||
if extension != ".pt" and extension != ".pth":
|
||||
continue
|
||||
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
|
||||
except Exception:
|
||||
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
def upscale(
|
||||
img,
|
||||
model,
|
||||
tile=opts.SWIN_tile,
|
||||
tile_overlap=opts.SWIN_tile_overlap,
|
||||
window_size=8,
|
||||
scale=4,
|
||||
):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(device)
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
_, _, h_old, w_old = img.size()
|
||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
||||
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
||||
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
||||
output = output[..., : h_old * scale, : w_old * scale]
|
||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
if output.ndim == 3:
|
||||
output = np.transpose(
|
||||
output[[2, 1, 0], :, :], (1, 2, 0)
|
||||
) # CHW-RGB to HCW-BGR
|
||||
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
||||
return Image.fromarray(output, "RGB")
|
||||
|
||||
|
||||
def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
# test the image tile by tile
|
||||
b, c, h, w = img.size()
|
||||
tile = min(tile, h, w)
|
||||
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
||||
sf = scale
|
||||
|
||||
stride = tile - tile_overlap
|
||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
||||
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
||||
|
||||
for h_idx in h_idx_list:
|
||||
for w_idx in w_idx_list:
|
||||
in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
|
||||
out_patch = model(in_patch)
|
||||
out_patch_mask = torch.ones_like(out_patch)
|
||||
|
||||
E[
|
||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
||||
].add_(out_patch)
|
||||
W[
|
||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
||||
].add_(out_patch_mask)
|
||||
output = E.div_(W)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class UpscalerSwin(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.model = load_model(filename)
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = self.model.to(device)
|
||||
img = upscale(img, model)
|
||||
return img
|
142
modules/swinir_model.py
Normal file
142
modules/swinir_model.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
import contextlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import modelloader
|
||||
from modules.paths import models_path
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_model_arch import SwinIR as net
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
precision_scope = (
|
||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
)
|
||||
|
||||
|
||||
class UpscalerSwinIR(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "SwinIR"
|
||||
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
||||
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
||||
"-L_x4_GAN.pth "
|
||||
self.model_name = "SwinIR 4x"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
scalers = []
|
||||
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
for model in model_files:
|
||||
if "http" in model:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(model)
|
||||
model_data = UpscalerData(name, model, self)
|
||||
scalers.append(model_data)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img, model_file):
|
||||
model = self.load_model(model_file)
|
||||
if model is None:
|
||||
return img
|
||||
model = model.to(device)
|
||||
img = upscale(img, model)
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
return img
|
||||
|
||||
def load_model(self, path, scale=4):
|
||||
if "http" in path:
|
||||
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
||||
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if filename is None or not os.path.exists(filename):
|
||||
return None
|
||||
model = net(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
window_size=8,
|
||||
img_range=1.0,
|
||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||
embed_dim=240,
|
||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||
mlp_ratio=2,
|
||||
upsampler="nearest+conv",
|
||||
resi_connection="3conv",
|
||||
)
|
||||
|
||||
pretrained_model = torch.load(filename)
|
||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
||||
if not cmd_opts.no_half:
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
|
||||
def upscale(
|
||||
img,
|
||||
model,
|
||||
tile=opts.SWIN_tile,
|
||||
tile_overlap=opts.SWIN_tile_overlap,
|
||||
window_size=8,
|
||||
scale=4,
|
||||
):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(device)
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
_, _, h_old, w_old = img.size()
|
||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
||||
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
||||
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
||||
output = output[..., : h_old * scale, : w_old * scale]
|
||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
if output.ndim == 3:
|
||||
output = np.transpose(
|
||||
output[[2, 1, 0], :, :], (1, 2, 0)
|
||||
) # CHW-RGB to HCW-BGR
|
||||
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
||||
return Image.fromarray(output, "RGB")
|
||||
|
||||
|
||||
def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
# test the image tile by tile
|
||||
b, c, h, w = img.size()
|
||||
tile = min(tile, h, w)
|
||||
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
||||
sf = scale
|
||||
|
||||
stride = tile - tile_overlap
|
||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
||||
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
||||
|
||||
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||
for h_idx in h_idx_list:
|
||||
for w_idx in w_idx_list:
|
||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||
out_patch = model(in_patch)
|
||||
out_patch_mask = torch.ones_like(out_patch)
|
||||
|
||||
E[
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch)
|
||||
W[
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch_mask)
|
||||
pbar.update(1)
|
||||
output = E.div_(W)
|
||||
|
||||
return output
|
File diff suppressed because it is too large
Load Diff
76
modules/textual_inversion/dataset.py
Normal file
76
modules/textual_inversion/dataset.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
import random
|
||||
import tqdm
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.size = size
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
||||
self.dataset = []
|
||||
|
||||
with open(template_file, "r") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.lines = lines
|
||||
|
||||
assert data_root, 'dataset directory not specified'
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
image = Image.open(path)
|
||||
image = image.convert('RGB')
|
||||
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
|
||||
filename_tokens = [token for token in filename_tokens if token.isalpha()]
|
||||
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
||||
torchdata = torch.moveaxis(torchdata, 2, 0)
|
||||
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
|
||||
self.dataset.append((init_latent, filename_tokens))
|
||||
|
||||
self.length = len(self.dataset) * repeats
|
||||
|
||||
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
||||
self.indexes = None
|
||||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
if i % len(self.dataset) == 0:
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[i % len(self.indexes)]
|
||||
x, filename_tokens = self.dataset[index]
|
||||
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||
|
||||
return x, text
|
258
modules/textual_inversion/textual_inversion.py
Normal file
258
modules/textual_inversion/textual_inversion.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
import html
|
||||
import datetime
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing
|
||||
import modules.textual_inversion.dataset
|
||||
|
||||
|
||||
class Embedding:
|
||||
def __init__(self, vec, name, step=None):
|
||||
self.vec = vec
|
||||
self.name = name
|
||||
self.step = step
|
||||
self.cached_checksum = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
"string_to_token": {"*": 265},
|
||||
"string_to_param": {"*": self.vec},
|
||||
"name": self.name,
|
||||
"step": self.step,
|
||||
}
|
||||
|
||||
torch.save(embedding_data, filename)
|
||||
|
||||
def checksum(self):
|
||||
if self.cached_checksum is not None:
|
||||
return self.cached_checksum
|
||||
|
||||
def const_hash(a):
|
||||
r = 0
|
||||
for v in a:
|
||||
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
||||
return r
|
||||
|
||||
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
||||
return self.cached_checksum
|
||||
|
||||
class EmbeddingDatabase:
|
||||
def __init__(self, embeddings_dir):
|
||||
self.ids_lookup = {}
|
||||
self.word_embeddings = {}
|
||||
self.dir_mtime = None
|
||||
self.embeddings_dir = embeddings_dir
|
||||
|
||||
def register_embedding(self, embedding, model):
|
||||
|
||||
self.word_embeddings[embedding.name] = embedding
|
||||
|
||||
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
||||
|
||||
first_id = ids[0]
|
||||
if first_id not in self.ids_lookup:
|
||||
self.ids_lookup[first_id] = []
|
||||
self.ids_lookup[first_id].append((ids, embedding))
|
||||
|
||||
return embedding
|
||||
|
||||
def load_textual_inversion_embeddings(self):
|
||||
mt = os.path.getmtime(self.embeddings_dir)
|
||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||
return
|
||||
|
||||
self.dir_mtime = mt
|
||||
self.ids_lookup.clear()
|
||||
self.word_embeddings.clear()
|
||||
|
||||
def process_file(path, filename):
|
||||
name = os.path.splitext(filename)[0]
|
||||
|
||||
data = torch.load(path, map_location="cpu")
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
|
||||
for fn in os.listdir(self.embeddings_dir):
|
||||
try:
|
||||
fullfn = os.path.join(self.embeddings_dir, fn)
|
||||
|
||||
if os.stat(fullfn).st_size == 0:
|
||||
continue
|
||||
|
||||
process_file(fullfn, fn)
|
||||
except Exception:
|
||||
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
continue
|
||||
|
||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||
|
||||
def find_embedding_at_position(self, tokens, offset):
|
||||
token = tokens[offset]
|
||||
possible_matches = self.ids_lookup.get(token, None)
|
||||
|
||||
if possible_matches is None:
|
||||
return None
|
||||
|
||||
for ids, embedding in possible_matches:
|
||||
if tokens[offset:offset + len(ids)] == ids:
|
||||
return embedding
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def create_embedding(name, num_vectors_per_token):
|
||||
init_text = '*'
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||
|
||||
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
|
||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||
|
||||
for i in range(num_vectors_per_token):
|
||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||
|
||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = 0
|
||||
embedding.save(fn)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
shared.state.job_count = steps
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
|
||||
|
||||
if save_embedding_every > 0:
|
||||
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||
os.makedirs(embedding_dir, exist_ok=True)
|
||||
else:
|
||||
embedding_dir = None
|
||||
|
||||
if create_image_every > 0:
|
||||
images_dir = os.path.join(log_directory, "images")
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||
embedding.vec.requires_grad = True
|
||||
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, (x, text) in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
|
||||
if embedding.step > steps:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([text])
|
||||
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
||||
|
||||
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
pbar.set_description(f"loss: {losses.mean():.7f}")
|
||||
|
||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||
embedding.save(last_saved_file)
|
||||
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
prompt=text,
|
||||
steps=20,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image += f", prompt: {text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {embedding.step}<br/>
|
||||
Last prompt: {html.escape(text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
|
||||
embedding.cached_checksum = None
|
||||
embedding.save(filename)
|
||||
|
||||
return embedding, filename
|
||||
|
32
modules/textual_inversion/ui.py
Normal file
32
modules/textual_inversion/ui.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import html
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import modules.textual_inversion.textual_inversion as ti
|
||||
from modules import sd_hijack, shared
|
||||
|
||||
|
||||
def create_embedding(name, nvpt):
|
||||
filename = ti.create_embedding(name, nvpt)
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
||||
|
||||
|
||||
def train_embedding(*args):
|
||||
|
||||
try:
|
||||
sd_hijack.undo_optimizations()
|
||||
|
||||
embedding, filename = ti.train_embedding(*args)
|
||||
|
||||
res = f"""
|
||||
Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps.
|
||||
Embedding saved to {html.escape(filename)}
|
||||
"""
|
||||
return res, ""
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.apply_optimizations()
|
164
modules/ui.py
164
modules/ui.py
|
@ -15,11 +15,13 @@ import subprocess as sp
|
|||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
|
||||
import gradio as gr
|
||||
import gradio.utils
|
||||
import gradio.routes
|
||||
|
||||
from modules import sd_hijack
|
||||
from modules.paths import script_path
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
|
@ -32,6 +34,7 @@ import modules.codeformer_model
|
|||
import modules.styles
|
||||
import modules.generation_parameters_copypaste
|
||||
from modules.images import apply_filename_pattern, get_next_sequence_number
|
||||
import modules.textual_inversion.ui
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
||||
mimetypes.init()
|
||||
|
@ -129,27 +132,37 @@ def save_files(js_data, images, index):
|
|||
writer = csv.writer(file)
|
||||
if at_start:
|
||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||
|
||||
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
||||
if file_decoration != "":
|
||||
file_decoration = "-" + file_decoration.lower()
|
||||
file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt)
|
||||
truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration
|
||||
filename_base = truncated
|
||||
extension = opts.samples_format.lower()
|
||||
|
||||
basecount = get_next_sequence_number(path, "")
|
||||
for i, filedata in enumerate(images):
|
||||
file_number = f"{basecount+i:05}"
|
||||
filename = file_number + filename_base + ".png"
|
||||
filename = file_number + filename_base + f".{extension}"
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text('parameters', infotexts[i])
|
||||
|
||||
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
|
||||
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||
if opts.enable_pnginfo and extension == 'png':
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text('parameters', infotexts[i])
|
||||
image.save(filepath, pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(filepath, quality=opts.jpeg_quality)
|
||||
|
||||
if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"):
|
||||
piexif.insert(piexif.dump({"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
|
||||
}}), filepath)
|
||||
|
||||
filenames.append(filename)
|
||||
|
||||
|
@ -158,8 +171,8 @@ def save_files(js_data, images, index):
|
|||
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def wrap_gradio_call(func):
|
||||
def f(*args, **kwargs):
|
||||
def wrap_gradio_call(func, extra_outputs=None):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
|
@ -175,7 +188,10 @@ def wrap_gradio_call(func):
|
|||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
|
||||
|
@ -195,6 +211,7 @@ def wrap_gradio_call(func):
|
|||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
||||
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
return tuple(res)
|
||||
|
||||
|
@ -203,7 +220,7 @@ def wrap_gradio_call(func):
|
|||
|
||||
def check_progress_call(id_part):
|
||||
if shared.state.job_count == 0:
|
||||
return "", gr_show(False), gr_show(False)
|
||||
return "", gr_show(False), gr_show(False), gr_show(False)
|
||||
|
||||
progress = 0
|
||||
|
||||
|
@ -235,13 +252,19 @@ def check_progress_call(id_part):
|
|||
else:
|
||||
preview_visibility = gr_show(True)
|
||||
|
||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
|
||||
if shared.state.textinfo is not None:
|
||||
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
||||
else:
|
||||
textinfo_result = gr_show(False)
|
||||
|
||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
||||
|
||||
|
||||
def check_progress_call_initial(id_part):
|
||||
shared.state.job_count = -1
|
||||
shared.state.current_latent = None
|
||||
shared.state.current_image = None
|
||||
shared.state.textinfo = None
|
||||
|
||||
return check_progress_call(id_part)
|
||||
|
||||
|
@ -396,7 +419,7 @@ def create_toprow(is_img2img):
|
|||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||
submit = gr.Button('Generate', elem_id="generate", variant='primary')
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
|
||||
interrupt.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
|
@ -415,13 +438,16 @@ def create_toprow(is_img2img):
|
|||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
|
||||
|
||||
|
||||
def setup_progressbar(progressbar, preview, id_part):
|
||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||
if textinfo is None:
|
||||
textinfo = gr.HTML(visible=False)
|
||||
|
||||
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
||||
check_progress.click(
|
||||
fn=lambda: check_progress_call(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
||||
|
||||
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
||||
|
@ -429,11 +455,14 @@ def setup_progressbar(progressbar, preview, id_part):
|
|||
fn=lambda: check_progress_call_initial(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
||||
|
||||
|
||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
|
||||
dummy_component = gr.Label(visible=False)
|
||||
|
@ -499,7 +528,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
||||
txt2img_args = dict(
|
||||
fn=txt2img,
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
||||
_js="submit",
|
||||
inputs=[
|
||||
txt2img_prompt,
|
||||
|
@ -615,7 +644,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
|
||||
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
|
||||
|
||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
|
||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
|
||||
|
||||
with gr.Row():
|
||||
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
|
||||
|
@ -691,7 +720,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
)
|
||||
|
||||
img2img_args = dict(
|
||||
fn=img2img,
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
|
||||
_js="submit_img2img",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
|
@ -844,7 +873,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||
|
||||
submit.click(
|
||||
fn=run_extras,
|
||||
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
_js="get_extras_tab_index",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
|
@ -894,7 +923,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
|
||||
image.change(
|
||||
fn=wrap_gradio_call(run_pnginfo),
|
||||
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||
inputs=[image],
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
|
@ -903,7 +932,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
|
||||
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
|
||||
|
@ -912,10 +941,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
||||
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
with gr.Blocks() as textual_inversion_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||
|
||||
new_embedding_name = gr.Textbox(label="Name")
|
||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
create_embedding = gr.Button(value="Create", variant='primary')
|
||||
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
train_embedding = gr.Button(value="Train", variant='primary')
|
||||
|
||||
with gr.Column():
|
||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
|
||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
||||
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
||||
|
||||
create_embedding.click(
|
||||
fn=modules.textual_inversion.ui.create_embedding,
|
||||
inputs=[
|
||||
new_embedding_name,
|
||||
nvpt,
|
||||
],
|
||||
outputs=[
|
||||
train_embedding_name,
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
train_embedding.click(
|
||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
train_embedding_name,
|
||||
learn_rate,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
steps,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
interrupt_training.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
def create_setting_component(key):
|
||||
def fun():
|
||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||
|
@ -1027,6 +1142,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||
(textual_inversion_interface, "Textual inversion", "ti"),
|
||||
(settings_interface, "Settings", "settings"),
|
||||
]
|
||||
|
||||
|
@ -1060,11 +1176,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = run_modelmerger(*args)
|
||||
results = modules.extras.run_modelmerger(*args)
|
||||
except Exception as e:
|
||||
print("Error loading/saving model file:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||
return results
|
||||
|
||||
|
|
121
modules/upscaler.py
Normal file
121
modules/upscaler.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared
|
||||
from modules import modelloader, shared
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class Upscaler:
|
||||
name = None
|
||||
model_path = None
|
||||
model_name = None
|
||||
model_url = None
|
||||
enable = True
|
||||
filter = None
|
||||
model = None
|
||||
user_path = None
|
||||
scalers: []
|
||||
tile = True
|
||||
|
||||
def __init__(self, create_dirs=False):
|
||||
self.mod_pad_h = None
|
||||
self.tile_size = modules.shared.opts.ESRGAN_tile
|
||||
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
||||
self.device = modules.shared.device
|
||||
self.img = None
|
||||
self.output = None
|
||||
self.scale = 1
|
||||
self.half = not modules.shared.cmd_opts.no_half
|
||||
self.pre_pad = 0
|
||||
self.mod_scale = None
|
||||
if self.name is not None and create_dirs:
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
if not os.path.exists(self.model_path):
|
||||
os.makedirs(self.model_path)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
self.can_tile = True
|
||||
except:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_upscale(self, img: PIL.Image, selected_model: str):
|
||||
return img
|
||||
|
||||
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
|
||||
self.scale = scale
|
||||
dest_w = img.width * scale
|
||||
dest_h = img.height * scale
|
||||
for i in range(3):
|
||||
if img.width >= dest_w and img.height >= dest_h:
|
||||
break
|
||||
img = self.do_upscale(img, selected_model)
|
||||
if img.width != dest_w or img.height != dest_h:
|
||||
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
||||
|
||||
return img
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, path: str):
|
||||
pass
|
||||
|
||||
def find_models(self, ext_filter=None) -> list:
|
||||
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
|
||||
|
||||
def update_status(self, prompt):
|
||||
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
|
||||
class UpscalerData:
|
||||
name = None
|
||||
data_path = None
|
||||
scale: int = 4
|
||||
scaler: Upscaler = None
|
||||
model: None
|
||||
|
||||
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
||||
self.name = name
|
||||
self.data_path = path
|
||||
self.scaler = upscaler
|
||||
self.scale = scale
|
||||
self.model = model
|
||||
|
||||
|
||||
class UpscalerNone(Upscaler):
|
||||
name = "None"
|
||||
scalers = []
|
||||
|
||||
def load_model(self, path):
|
||||
pass
|
||||
|
||||
def do_upscale(self, img, selected_model=None):
|
||||
return img
|
||||
|
||||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.scalers = [UpscalerData("None", None, self)]
|
||||
|
||||
|
||||
class UpscalerLanczos(Upscaler):
|
||||
scalers = []
|
||||
|
||||
def do_upscale(self, img, selected_model=None):
|
||||
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
||||
|
||||
def load_model(self, _):
|
||||
pass
|
||||
|
||||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.name = "Lanczos"
|
||||
self.scalers = [UpscalerData("Lanczos", None, self)]
|
||||
|
|
@ -4,7 +4,7 @@ fairscale==0.4.4
|
|||
fonts
|
||||
font-roboto
|
||||
gfpgan
|
||||
gradio
|
||||
gradio==3.4b3
|
||||
invisible-watermark
|
||||
numpy
|
||||
omegaconf
|
||||
|
|
|
@ -11,46 +11,8 @@ from modules import images, processing, devices
|
|||
from modules.processing import Processed, process_images
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
|
||||
# https://github.com/parlance-zz/g-diffuser-bot
|
||||
def expand(x, dir, amount, power=0.75):
|
||||
is_left = dir == 3
|
||||
is_right = dir == 1
|
||||
is_up = dir == 0
|
||||
is_down = dir == 2
|
||||
|
||||
if is_left or is_right:
|
||||
noise = np.zeros((x.shape[0], amount, 3), dtype=float)
|
||||
indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
|
||||
if is_right:
|
||||
indexes = 1 - indexes
|
||||
indexes = (indexes * (x.shape[1] - 1)).astype(int)
|
||||
|
||||
for row in range(x.shape[0]):
|
||||
if is_left:
|
||||
noise[row] = x[row][indexes[row]]
|
||||
else:
|
||||
noise[row] = np.flip(x[row][indexes[row]], axis=0)
|
||||
|
||||
x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
|
||||
return x
|
||||
|
||||
if is_up or is_down:
|
||||
noise = np.zeros((amount, x.shape[1], 3), dtype=float)
|
||||
indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
|
||||
if is_down:
|
||||
indexes = 1 - indexes
|
||||
indexes = (indexes * x.shape[0] - 1).astype(int)
|
||||
|
||||
for row in range(x.shape[1]):
|
||||
if is_up:
|
||||
noise[:, row] = x[:, row][indexes[row]]
|
||||
else:
|
||||
noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
|
||||
|
||||
x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
|
||||
return x
|
||||
|
||||
|
||||
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
|
||||
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
|
||||
# helper fft routines that keep ortho normalization and auto-shift before and after fft
|
||||
def _fft2(data):
|
||||
|
|
|
@ -34,7 +34,7 @@ class Script(scripts.Script):
|
|||
seed = p.seed
|
||||
|
||||
init_img = p.init_images[0]
|
||||
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
|
||||
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
|
|
|
@ -45,11 +45,8 @@ def apply_sampler(p, x, xs):
|
|||
|
||||
|
||||
def apply_checkpoint(p, x, xs):
|
||||
applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
|
||||
assert len(applicable) > 0, f'Checkpoint {x} for found'
|
||||
|
||||
info = applicable[0]
|
||||
|
||||
info = modules.sd_models.get_closet_checkpoint_match(x)
|
||||
assert info is not None, f'Checkpoint for {x} not found'
|
||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||
|
||||
|
||||
|
@ -159,6 +156,9 @@ class Script(scripts.Script):
|
|||
p.batch_size = 1
|
||||
|
||||
def process_axis(opt, vals):
|
||||
if opt.label == 'Nothing':
|
||||
return [0]
|
||||
|
||||
valslist = [x.strip() for x in vals.split(",")]
|
||||
|
||||
if opt.type == int:
|
||||
|
|
12
style.css
12
style.css
|
@ -23,7 +23,7 @@
|
|||
text-align: right;
|
||||
}
|
||||
|
||||
#generate{
|
||||
#txt2img_generate, #img2img_generate {
|
||||
min-height: 4.5em;
|
||||
}
|
||||
|
||||
|
@ -157,7 +157,7 @@ button{
|
|||
max-width: 10em;
|
||||
}
|
||||
|
||||
#txt2img_preview, #img2img_preview{
|
||||
#txt2img_preview, #img2img_preview, #ti_preview{
|
||||
position: absolute;
|
||||
width: 320px;
|
||||
left: 0;
|
||||
|
@ -172,18 +172,18 @@ button{
|
|||
}
|
||||
|
||||
@media screen and (min-width: 768px) {
|
||||
#txt2img_preview, #img2img_preview {
|
||||
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||
position: absolute;
|
||||
}
|
||||
}
|
||||
|
||||
@media screen and (max-width: 767px) {
|
||||
#txt2img_preview, #img2img_preview {
|
||||
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||
position: relative;
|
||||
}
|
||||
}
|
||||
|
||||
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
|
||||
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
|
||||
display: none;
|
||||
}
|
||||
|
||||
|
@ -247,7 +247,7 @@ input[type="range"]{
|
|||
#txt2img_negative_prompt, #img2img_negative_prompt{
|
||||
}
|
||||
|
||||
#txt2img_progressbar, #img2img_progressbar{
|
||||
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
|
||||
position: absolute;
|
||||
z-index: 1000;
|
||||
right: 0;
|
||||
|
|
19
textual_inversion_templates/style.txt
Normal file
19
textual_inversion_templates/style.txt
Normal file
|
@ -0,0 +1,19 @@
|
|||
a painting, art by [name]
|
||||
a rendering, art by [name]
|
||||
a cropped painting, art by [name]
|
||||
the painting, art by [name]
|
||||
a clean painting, art by [name]
|
||||
a dirty painting, art by [name]
|
||||
a dark painting, art by [name]
|
||||
a picture, art by [name]
|
||||
a cool painting, art by [name]
|
||||
a close-up painting, art by [name]
|
||||
a bright painting, art by [name]
|
||||
a cropped painting, art by [name]
|
||||
a good painting, art by [name]
|
||||
a close-up painting, art by [name]
|
||||
a rendition, art by [name]
|
||||
a nice painting, art by [name]
|
||||
a small painting, art by [name]
|
||||
a weird painting, art by [name]
|
||||
a large painting, art by [name]
|
19
textual_inversion_templates/style_filewords.txt
Normal file
19
textual_inversion_templates/style_filewords.txt
Normal file
|
@ -0,0 +1,19 @@
|
|||
a painting of [filewords], art by [name]
|
||||
a rendering of [filewords], art by [name]
|
||||
a cropped painting of [filewords], art by [name]
|
||||
the painting of [filewords], art by [name]
|
||||
a clean painting of [filewords], art by [name]
|
||||
a dirty painting of [filewords], art by [name]
|
||||
a dark painting of [filewords], art by [name]
|
||||
a picture of [filewords], art by [name]
|
||||
a cool painting of [filewords], art by [name]
|
||||
a close-up painting of [filewords], art by [name]
|
||||
a bright painting of [filewords], art by [name]
|
||||
a cropped painting of [filewords], art by [name]
|
||||
a good painting of [filewords], art by [name]
|
||||
a close-up painting of [filewords], art by [name]
|
||||
a rendition of [filewords], art by [name]
|
||||
a nice painting of [filewords], art by [name]
|
||||
a small painting of [filewords], art by [name]
|
||||
a weird painting of [filewords], art by [name]
|
||||
a large painting of [filewords], art by [name]
|
27
textual_inversion_templates/subject.txt
Normal file
27
textual_inversion_templates/subject.txt
Normal file
|
@ -0,0 +1,27 @@
|
|||
a photo of a [name]
|
||||
a rendering of a [name]
|
||||
a cropped photo of the [name]
|
||||
the photo of a [name]
|
||||
a photo of a clean [name]
|
||||
a photo of a dirty [name]
|
||||
a dark photo of the [name]
|
||||
a photo of my [name]
|
||||
a photo of the cool [name]
|
||||
a close-up photo of a [name]
|
||||
a bright photo of the [name]
|
||||
a cropped photo of a [name]
|
||||
a photo of the [name]
|
||||
a good photo of the [name]
|
||||
a photo of one [name]
|
||||
a close-up photo of the [name]
|
||||
a rendition of the [name]
|
||||
a photo of the clean [name]
|
||||
a rendition of a [name]
|
||||
a photo of a nice [name]
|
||||
a good photo of a [name]
|
||||
a photo of the nice [name]
|
||||
a photo of the small [name]
|
||||
a photo of the weird [name]
|
||||
a photo of the large [name]
|
||||
a photo of a cool [name]
|
||||
a photo of a small [name]
|
27
textual_inversion_templates/subject_filewords.txt
Normal file
27
textual_inversion_templates/subject_filewords.txt
Normal file
|
@ -0,0 +1,27 @@
|
|||
a photo of a [name], [filewords]
|
||||
a rendering of a [name], [filewords]
|
||||
a cropped photo of the [name], [filewords]
|
||||
the photo of a [name], [filewords]
|
||||
a photo of a clean [name], [filewords]
|
||||
a photo of a dirty [name], [filewords]
|
||||
a dark photo of the [name], [filewords]
|
||||
a photo of my [name], [filewords]
|
||||
a photo of the cool [name], [filewords]
|
||||
a close-up photo of a [name], [filewords]
|
||||
a bright photo of the [name], [filewords]
|
||||
a cropped photo of a [name], [filewords]
|
||||
a photo of the [name], [filewords]
|
||||
a good photo of the [name], [filewords]
|
||||
a photo of one [name], [filewords]
|
||||
a close-up photo of the [name], [filewords]
|
||||
a rendition of the [name], [filewords]
|
||||
a photo of the clean [name], [filewords]
|
||||
a rendition of a [name], [filewords]
|
||||
a photo of a nice [name], [filewords]
|
||||
a good photo of a [name], [filewords]
|
||||
a photo of the nice [name], [filewords]
|
||||
a photo of the small [name], [filewords]
|
||||
a photo of the weird [name], [filewords]
|
||||
a photo of the large [name], [filewords]
|
||||
a photo of a cool [name], [filewords]
|
||||
a photo of a small [name], [filewords]
|
|
@ -21,6 +21,9 @@ export COMMANDLINE_ARGS=""
|
|||
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
|
||||
#venv_dir="venv"
|
||||
|
||||
# script to launch to start the app
|
||||
#export LAUNCH_SCRIPT="launch.py"
|
||||
|
||||
# install command for torch
|
||||
#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
|
||||
|
||||
|
|
59
webui.py
59
webui.py
|
@ -3,36 +3,34 @@ import threading
|
|||
|
||||
from modules import devices
|
||||
from modules.paths import script_path
|
||||
|
||||
import signal
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.ui
|
||||
import threading
|
||||
import modules.paths
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.esrgan_model as esrgan
|
||||
import modules.bsrgan_model as bsrgan
|
||||
import modules.extras
|
||||
import modules.face_restoration
|
||||
import modules.gfpgan_model as gfpgan
|
||||
import modules.ldsr_model as ldsr
|
||||
import modules.lowvram
|
||||
import modules.realesrgan_model as realesrgan
|
||||
import modules.scripts
|
||||
import modules.sd_hijack
|
||||
import modules.codeformer_model
|
||||
import modules.gfpgan_model
|
||||
import modules.face_restoration
|
||||
import modules.realesrgan_model as realesrgan
|
||||
import modules.esrgan_model as esrgan
|
||||
import modules.ldsr_model as ldsr
|
||||
import modules.extras
|
||||
import modules.lowvram
|
||||
import modules.txt2img
|
||||
import modules.img2img
|
||||
import modules.swinir as swinir
|
||||
import modules.sd_models
|
||||
import modules.shared as shared
|
||||
import modules.swinir_model as swinir
|
||||
import modules.ui
|
||||
from modules import modelloader
|
||||
from modules.paths import script_path
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
|
||||
modules.codeformer_model.setup_codeformer()
|
||||
modules.gfpgan_model.setup_gfpgan()
|
||||
modelloader.cleanup_models()
|
||||
modules.sd_models.setup_model(cmd_opts.ckpt_dir)
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||
|
||||
esrgan.load_models(cmd_opts.esrgan_models_path)
|
||||
swinir.load_models(cmd_opts.swinir_models_path)
|
||||
realesrgan.setup_realesrgan()
|
||||
ldsr.add_lsdr()
|
||||
modelloader.load_upscalers()
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
|
@ -46,7 +44,7 @@ def wrap_queued_call(func):
|
|||
return f
|
||||
|
||||
|
||||
def wrap_gradio_gpu_call(func):
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
def f(*args, **kwargs):
|
||||
devices.torch_gc()
|
||||
|
||||
|
@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func):
|
|||
shared.state.current_image = None
|
||||
shared.state.current_image_sampling_step = 0
|
||||
shared.state.interrupted = False
|
||||
shared.state.textinfo = None
|
||||
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func):
|
|||
|
||||
return res
|
||||
|
||||
return modules.ui.wrap_gradio_call(f)
|
||||
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||
|
||||
|
||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
||||
|
@ -86,13 +85,7 @@ def webui():
|
|||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
demo = modules.ui.create_ui(
|
||||
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
||||
img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
|
||||
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
run_pnginfo=modules.extras.run_pnginfo,
|
||||
run_modelmerger=modules.extras.run_modelmerger
|
||||
)
|
||||
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
||||
|
||||
demo.launch(
|
||||
share=cmd_opts.share,
|
||||
|
|
7
webui.sh
7
webui.sh
|
@ -41,6 +41,11 @@ then
|
|||
venv_dir="venv"
|
||||
fi
|
||||
|
||||
if [[ -z "${LAUNCH_SCRIPT}" ]]
|
||||
then
|
||||
LAUNCH_SCRIPT="launch.py"
|
||||
fi
|
||||
|
||||
# Disable sentry logging
|
||||
export ERROR_REPORTING=FALSE
|
||||
|
||||
|
@ -133,4 +138,4 @@ fi
|
|||
printf "\n%s\n" "${delimiter}"
|
||||
printf "Launching launch.py..."
|
||||
printf "\n%s\n" "${delimiter}"
|
||||
"${python_cmd}" launch.py
|
||||
"${python_cmd}" "${LAUNCH_SCRIPT}"
|
||||
|
|
Loading…
Reference in New Issue
Block a user