stable-diffusion-webui/modules/processing.py
2022-09-17 22:09:52 -07:00

546 lines
22 KiB
Python

import contextlib
import json
import math
import os
import sys
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
import modules.sd_hijack
from modules import devices, prompt_parser
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
# some of those options should not be changed at all because they would break the model, so I removed them from options.
# opt_C is used as the first dimension of the noise shape built in process_images.
opt_C = 4
# opt_f is the divisor applied to the requested image height and width to get the noise height and width.
opt_f = 8
def setup_color_correction(image):
    """Return the LAB-space pixel array of *image* to be used as a color-correction target.

    The image is copied, viewed as a numpy array and converted from RGB to LAB with OpenCV.
    The result is what apply_color_correction() expects as its ``correction`` argument.
    """
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)
def apply_color_correction(correction, image):
    """Match the color histogram of *image* to a previously captured LAB target.

    correction: LAB array as produced by setup_color_correction().
    image: image whose pixels are converted RGB -> LAB, histogram-matched against
        *correction* along channel axis 2, and converted back to RGB.

    Returns a new image built from the corrected uint8 pixels.
    """
    image_lab = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(image_lab, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")
    return Image.fromarray(matched_rgb)
class StableDiffusionProcessing:
    """Base container for the parameters and per-job state of one generation request.

    Subclasses override init() to prepare the sampler (and any input images) and
    sample() to run the sampler; the base class only stores the parameters.
    """

    # Both subclasses in this module declare this attribute and create_random_tensors()
    # reads ``p.sampler`` on whatever processing object it is given, so declare it on the
    # base class too: a base (or custom subclass) instance then reads None instead of
    # raising AttributeError.
    sampler = None

    def __init__(
        self,
        sd_model=None,
        outpath_samples=None,
        outpath_grids=None,
        prompt="",
        styles=None,
        seed=-1,
        subseed=-1,
        subseed_strength=0,
        seed_resize_from_h=-1,
        seed_resize_from_w=-1,
        sampler_index=0,
        batch_size=1,
        n_iter=1,
        steps=50,
        cfg_scale=7.0,
        width=512,
        height=512,
        restore_faces=False,
        tiling=False,
        do_not_save_samples=False,
        do_not_save_grid=False,
        extra_generation_params=None,
        overlay_images=None,
        negative_prompt=None,
    ):
        """Store the generation parameters; no model work happens here.

        A ``negative_prompt`` of None is normalised to the empty string; every other
        argument is stored unchanged under the attribute of the same name.
        """
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: str = styles
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        # Not constructor arguments: start unset and may be filled in by a subclass's init().
        self.paste_to = None
        self.color_corrections = None

    def init(self, seed):
        """Hook run once before sampling starts; the base implementation does nothing."""
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        """Run the sampler on noise *x*; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()
class Processed:
    """Result of a processing run: the produced images plus the key parameters used."""

    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
        """Capture the images, seed and info text together with a snapshot of *p*'s settings."""
        self.images = images_list
        self.seed = seed
        self.info = info

        # Snapshot of the request parameters.
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.width = p.width
        self.height = p.height
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        # The sampler is stored by display name, looked up from the sampler index.
        self.sampler = samplers[p.sampler_index].name

    def js(self):
        """Serialise the stored parameters to a JSON string.

        Prompt, negative prompt and seed may each be a list; in that case only the
        first element is reported. The seed is always emitted as an int.
        """

        def first_if_list(value):
            # Exact-type check: only a plain list is unwrapped.
            if type(value) == list:
                return value[0]
            return value

        payload = {}
        payload["prompt"] = first_if_list(self.prompt)
        payload["negative_prompt"] = first_if_list(self.negative_prompt)
        payload["seed"] = int(first_if_list(self.seed))
        payload["width"] = self.width
        payload["height"] = self.height
        payload["sampler"] = self.sampler
        payload["cfg_scale"] = self.cfg_scale
        payload["steps"] = self.steps
        return json.dumps(payload)
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate between *low* and *high* along dimension 1.

    val: interpolation factor; 0 yields *low*, 1 yields *high*.
    low, high: tensors of the same shape. The angle between them is computed from
        their dot product over dimension 1 after normalising along that dimension.

    Returns a tensor of the same shape as the inputs.

    Where the two directions are (anti)parallel the sine of the angle is zero and the
    slerp weights would be 0/0 (NaN); those positions fall back to plain linear
    interpolation. All other positions get the usual slerp weights.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)

    # Rounding can push the dot product marginally outside [-1, 1], where acos returns NaN.
    dot = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(dot)
    so = torch.sin(omega)

    # Positions where dividing by sin(omega) is not meaningful.
    degenerate = so.abs() < 1e-6
    # Use a harmless divisor there so that no NaN/inf is produced in the unused branch.
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    ones = torch.ones_like(so)
    low_weight = torch.where(degenerate, (1.0 - val) * ones, torch.sin((1.0 - val) * omega) / safe_so)
    high_weight = torch.where(degenerate, val * ones, torch.sin(val * omega) / safe_so)

    res = low_weight.unsqueeze(1) * low + high_weight.unsqueeze(1) * high
    return res
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Build one noise tensor per seed and return them stacked on shared.device.

    shape: shape of a single noise tensor; indexed as shape[0], shape[1], shape[2] below.
    seeds: one seed per tensor to generate.
    subseeds: optional per-tensor variation seeds; a missing entry is treated as 0.
    subseed_strength: slerp factor between the seed noise and the subseed noise.
    seed_resize_from_h, seed_resize_from_w: when both are > 0, noise is generated at
        (shape[0], h // 8, w // 8) and then centred/cropped into a tensor of *shape*.
    p: optional processing object; when it has a sampler, there is more than one seed and
        opts.enable_batch_seeds is set, extra per-seed noises are generated and stored on
        p.sampler.sampler_noises (side effect).

    NOTE(review): the order of the devices.randn* calls determines the values produced;
    do not reorder statements without checking how those helpers seed the generator.
    """
    xs = []
    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None
    for i, seed in enumerate(seeds):
        # Generate at the "resize from" resolution only when both dimensions are positive.
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]
            subnoise = devices.randn(subseed, noise_shape)
        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)
        if subnoise is not None:
            #noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
            noise = slerp(subseed_strength, noise, subnoise)
        if noise_shape != shape:
            #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
            # Start from fresh noise of the target shape and paste the resized-seed noise
            # into its centre (or crop the centre of the noise if it is larger).
            x = devices.randn(seed, shape)
            # Half of the size difference per axis; negative when the noise is larger than the target.
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            # Size of the region that is actually copied.
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            # Top-left corner of the destination region in x.
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            # From here on dx/dy are the top-left corner of the source region in noise.
            dx = max(-dx, 0)
            dy = max(-dy, 0)
            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x
        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)
            # NOTE(review): these use noise_shape, not shape — confirm this is intended when seed resize is active.
            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
        xs.append(noise)
    if sampler_noises is not None:
        # One stacked tensor per needed noise, each holding one entry per seed.
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
    x = torch.stack(xs).to(shared.device)
    return x
def fix_seed(p):
    """Replace unset seeds on *p* with random ones, in place.

    A seed or subseed that is None or -1 is replaced by a random integer drawn with
    random.randrange(4294967294); any other value is left untouched.
    """

    def resolve(value):
        if value is None or value == -1:
            return int(random.randrange(4294967294))
        return value

    p.seed = resolve(p.seed)
    p.subseed = resolve(p.subseed)
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Main generation loop shared by txt2img and img2img.

    Calls p.init() once inside the no_grad / precision / EMA scopes and p.sample() once
    per batch. Each sampled batch is decoded, optionally filtered and face-restored,
    color-corrected, composited with overlays, saved and collected.

    Side effects visible here: creates the output directories, fixes p.seed/p.subseed,
    applies prompt styles to p, updates the shared job state, writes images to disk
    according to opts, and resets p.color_corrections to None.

    Returns a Processed holding the output images (with the grid first when
    opts.return_grid is set), the first seed and the info text of the first image.
    """
    # A list prompt must be non-empty; a single prompt must not be None.
    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None
    devices.torch_gc()
    fix_seed(p)
    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)
    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    # Used as an ordered set: keys are comment strings, values are ignored.
    comments = {}
    shared.prompt_styles.apply_styles(p)
    # Expand prompt / seed / subseed into one entry per image unless a list was supplied.
    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]
    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        # Seeds count up per image, except with a non-zero subseed strength, where every
        # image keeps the same seed (the subseeds below still count up).
        all_seeds = [int(p.seed + (x if p.subseed_strength == 0 else 0)) for x in range(len(all_prompts))]
    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]
    def infotext(iteration=0, position_in_batch=0):
        """Return the text describing how the image at the given batch position was generated."""
        index = position_in_batch + iteration * p.batch_size
        # Entries whose value is None are left out of the text below.
        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[index],
            "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
            "Size": f"{p.width}x{p.height}",
            "Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
            "Batch size": (None if p.batch_size < 2 else p.batch_size),
            "Batch pos": (None if p.batch_size < 2 else position_in_batch),
            "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
            "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
            "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
            "Denoising strength": getattr(p, 'denoising_strength', None),
        }
        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)
        # A parameter whose value equals its key is printed as the bare key.
        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
        return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(seed=all_seeds[0])
        # -1 is treated as "job count not set yet".
        if state.job_count == -1:
            state.job_count = p.n_iter
        for n in range(p.n_iter):
            if state.interrupted:
                break
            # Slice out the entries belonging to batch n.
            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
            # A supplied prompt list may be shorter than batch_size * n_iter.
            if (len(prompts) == 0):
                break
            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
            c = prompt_parser.get_learned_conditioning(prompts, p.steps)
            # Collect comments reported by the model hijack so that infotext() can append them.
            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1
            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"
            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            if state.interrupted:
                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent
            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            # Map the decoded values from [-1, 1] to [0, 1], clamping anything outside.
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
            if opts.filter_nsfw:
                # NOTE(review): the "safety" alias is unused; the call goes through modules.safety.
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
            for i, x_sample in enumerate(x_samples_ddim):
                # Move axis 0 to the end and scale to 0..255 before converting to uint8 for PIL.
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                if p.restore_faces:
                    # NOTE(review): this reads opts.save while the final save below reads
                    # opts.samples_save — confirm which option name is intended.
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
                    devices.torch_gc()
                    x_sample = modules.face_restoration.restore_faces(x_sample)
                image = Image.fromarray(x_sample)
                if p.color_corrections is not None and i < len(p.color_corrections):
                    image = apply_color_correction(p.color_corrections[i], image)
                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]
                    if p.paste_to is not None:
                        # Resize the result to the paste region and place it on a canvas of the
                        # overlay's size. This rebinds x (the noise tensor name), which is
                        # harmless because x is recreated at the start of the next batch.
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image
                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')
                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
                output_images.append(image)
            state.nextjob()
        p.color_corrections = None
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)
            if opts.return_grid:
                # The grid goes first so that it is the first image returned.
                output_images.insert(0, grid)
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)
    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext())
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Processing job that generates images from noise alone."""

    sampler = None

    def init(self, seed):
        """Instantiate the sampler selected by sampler_index for this job's model."""
        sampler_entry = samplers[self.sampler_index]
        self.sampler = sampler_entry.constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        """Run the sampler on noise *x* with the given conditionings and return its output."""
        return self.sampler.sample(self, x, conditioning, unconditional_conditioning)
def get_crop_region(mask, pad=0):
    """Find the bounding box of the non-zero area of a 2-D mask.

    mask: 2-D array; any non-zero element counts as part of the masked area.
    pad: number of pixels by which the box is grown on every side, clamped to the
        mask's borders.

    Returns (x1, y1, x2, y2) as ints. For a mask with no non-zero element the empty
    margins cover the whole mask, so with pad=0 the result is (w, h, 0, 0).
    """
    height, width = mask.shape

    occupied = mask != 0
    used_cols = np.flatnonzero(occupied.any(axis=0))
    used_rows = np.flatnonzero(occupied.any(axis=1))

    # Number of completely empty columns on each side.
    if used_cols.size:
        empty_left = used_cols[0]
        empty_right = width - 1 - used_cols[-1]
    else:
        empty_left = empty_right = width

    # Number of completely empty rows on each side.
    if used_rows.size:
        empty_top = used_rows[0]
        empty_bottom = height - 1 - used_rows[-1]
    else:
        empty_top = empty_bottom = height

    x1 = int(max(empty_left - pad, 0))
    y1 = int(max(empty_top - pad, 0))
    x2 = int(min(width - empty_right + pad, width))
    y2 = int(min(height - empty_bottom + pad, height))
    return (x1, y1, x2, y2)
def fill(image, mask):
    """Fill the masked area of *image* with blurred colors taken from the unmasked area.

    The image is pasted through the inverted mask onto a transparent 'RGBa' canvas, so
    only pixels outside the mask remain. Gaussian-blurred versions of that canvas, from
    radius 256 down to 0, are then alpha-composited on top of each other.

    Returns the composited result converted to 'RGB'.
    """
    canvas_size = (image.width, image.height)
    result = Image.new('RGBA', canvas_size)

    unmasked = Image.new('RGBa', canvas_size)
    source = image.convert("RGBA").convert("RGBa")
    keep = ImageOps.invert(mask.convert('L'))
    unmasked.paste(source, mask=keep)
    unmasked = unmasked.convert('RGBa')

    # (blur radius, number of times the blurred layer is composited)
    blur_schedule = [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]
    for radius, repeats in blur_schedule:
        layer = unmasked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            result.alpha_composite(layer)

    return result.convert("RGB")
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Processing job that starts from one or more input images, optionally restricted by a mask."""
    sampler = None
    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpainting_mask_invert=0, **kwargs):
        """Store the img2img-specific settings; all other keyword arguments go to the base class.

        init_images: input images; each must support convert("RGB") (see init()).
        resize_mode: mode passed to images.resize_image for inputs and the mask.
        mask: optional mask image, stored as self.image_mask.
        mask_blur: Gaussian blur radius applied to the mask when > 0.
        inpainting_fill: selects how the masked area is prepared in init().
        inpaint_full_res: when true, only the crop around the masked area is processed
            and the result is pasted back via self.paste_to.
        inpainting_mask_invert: when true, the mask is inverted before use.
        """
        super().__init__(**kwargs)
        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpainting_mask_invert = inpainting_mask_invert
        # Latent-space weights computed in init(): mask multiplies the original latent,
        # nmask multiplies the newly sampled latent (see sample()).
        self.mask = None
        self.nmask = None
    def init(self, seed):
        """Create the sampler, prepare mask and overlays, and encode the input images to latents.

        seed: base seed; only used to derive noise seeds when inpainting_fill == 2.

        Raises:
            RuntimeError: if more init images than batch_size are supplied.
        """
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None
        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')
            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)
            #self.image_unblurred_mask = self.image_mask
            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
            if self.inpaint_full_res:
                # Work only on the bounding box of the mask (plus padding); the overlay
                # keeps the full-size mask so the result can be pasted back later.
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
                x1, y1, x2, y2 = crop_region
                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                # (x, y, width, height) of the region the result is pasted into.
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                # The overlay mask is the resized mask with its values doubled and saturated at 255.
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)
            self.overlay_images = []
        # An explicitly assigned latent_mask takes precedence over the image mask.
        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
        # Only capture correction targets when the option is on and none were supplied.
        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")
            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)
            if self.image_mask is not None:
                # Overlay = the image with the masked area cut out (pasted through the inverted overlay mask).
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
                self.overlay_images.append(image_masked.convert('RGBA'))
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)
            if self.image_mask is not None:
                # Every fill mode except 1 pre-fills the masked area with blurred surrounding colors.
                if self.inpainting_fill != 1:
                    image = fill(image, latent_mask)
            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))
            # Scale to 0..1 and move axis 2 to the front.
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            imgs.append(image)
        if len(imgs) == 1:
            # A single input image is repeated for every item of the batch, as are its overlays.
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            # Several input images: the batch size shrinks to the number of images.
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
        image = torch.from_numpy(batch_images)
        # Map 0..1 to -1..1 before encoding.
        image = 2. * image - 1.
        image = image.to(shared.device)
        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
        if self.image_mask is not None:
            init_mask = latent_mask
            # Resize the mask to the latent's last two dimensions, keep one channel,
            # round it to 0/1 and repeat it over 4 channels.
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))
            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
            if self.inpainting_fill == 2:
                # Replace the masked part of the latent with noise seeded from seed + 1 onwards.
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                # Zero out the masked part of the latent.
                self.init_latent = self.init_latent * self.mask
    def sample(self, x, conditioning, unconditional_conditioning):
        """Run img2img sampling from init_latent with noise *x* and re-apply the mask, if any."""
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
        if self.mask is not None:
            # Keep sampled values where nmask is 1 and the initial latent where mask is 1.
            samples = samples * self.nmask + self.init_latent * self.mask
        return samples