Merge branch 'master' into test

commit cc6f8a9c9a
modules/devices.py
@@ -1,22 +1,17 @@
-import sys, os, shlex
+import sys
 import contextlib
 import torch
 from modules import errors
-from modules.sd_hijack_utils import CondFunc
-from packaging import version
+if sys.platform == "darwin":
+    from modules import mac_specific
 
 
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
 def has_mps() -> bool:
-    if not getattr(torch, 'has_mps', False):
+    if sys.platform != "darwin":
         return False
-    try:
-        torch.zeros(1).to(torch.device("mps"))
-        return True
-    except Exception:
-        return False
+    else:
+        return mac_specific.has_mps
 
 
 def extract_device_id(args, name):
     for x in range(len(args)):
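Callers elsewhere in the codebase continue to use `devices.has_mps()`; only its implementation now delegates to the new `mac_specific` module, and the import itself is gated on `sys.platform` so non-macOS systems never load it. A brief usage sketch (the device-picking line is illustrative, not code from this commit):

```python
import torch

from modules import devices  # webui module; has_mps() defined as above

# Prefer Apple's Metal backend when available, otherwise fall back to CPU.
device = torch.device("mps") if devices.has_mps() else torch.device("cpu")
x = torch.ones(3, device=device)
```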
@@ -155,36 +150,3 @@ def test_for_nans(x, where):
     message += " Use --disable-nan-check commandline argument to disable this check."
 
     raise NansException(message)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
-    if input.device.type == 'mps':
-        output_dtype = kwargs.get('dtype', input.dtype)
-        if output_dtype == torch.int64:
-            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
-            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
-    return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps():
-    if version.parse(torch.__version__) < version.parse("1.13"):
-        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
-            lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
-        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
-            lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
-    elif version.parse(torch.__version__) > version.parse("1.13.1"):
-        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
-        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
-        CondFunc('torch.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-
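The `cumsum_fix` block removed here reappears unchanged in the new `modules/mac_specific.py` further down. Its dtype routing works around MPS `cumsum` bugs in older PyTorch builds: int64 input is computed on the CPU and moved back, while bool/int8/int16 input is upcast to int32 and the result widened to int64. A standalone sketch of that routing, runnable on any device (the function name is illustrative):

```python
import torch

def cumsum_with_mps_workaround(t: torch.Tensor) -> torch.Tensor:
    # Mirrors the routing in cumsum_fix; only MPS tensors need special handling.
    if t.device.type == 'mps':
        if t.dtype == torch.int64:
            return t.cpu().cumsum(0).to(t.device)  # int64: compute on CPU
        if t.dtype in (torch.bool, torch.int8, torch.int16):
            return t.to(torch.int32).cumsum(0).to(torch.int64)  # upcast path
    return t.cumsum(0)
```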
modules/hashes.py
@@ -4,6 +4,7 @@ import os.path
 
 import filelock
 
+from modules import shared
 from modules.paths import data_path
 
 
@@ -68,6 +69,9 @@ def sha256(filename, title):
     if sha256_value is not None:
         return sha256_value
 
+    if shared.cmd_opts.no_hashing:
+        return None
+
     print(f"Calculating sha256 for {filename}: ", end='')
     sha256_value = calculate_sha256(filename)
     print(f"{sha256_value}")
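With `--no-hashing` set, `sha256()` now returns `None` before the expensive digest is computed, which is why the hypernetwork and checkpoint shorthash code below gains `None` checks. `calculate_sha256` itself is not shown in this diff; a minimal chunked implementation along the same lines (the name suffix and block size are assumptions, not the file's actual code):

```python
import hashlib

def calculate_sha256_sketch(filename: str, blocksize: int = 1 << 20) -> str:
    # Hash in fixed-size blocks so multi-gigabyte checkpoints never have to
    # fit in memory at once.
    hasher = hashlib.sha256()
    with open(filename, "rb") as f:
        for block in iter(lambda: f.read(blocksize), b""):
            hasher.update(block)
    return hasher.hexdigest()
```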
modules/hypernetworks/hypernetwork.py
@@ -307,7 +307,7 @@ class Hypernetwork:
     def shorthash(self):
         sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
 
-        return sha256[0:10]
+        return sha256[0:10] if sha256 else None
 
 
 def list_hypernetworks(path):
modules/img2img.py
@@ -76,7 +76,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
         processed_image.save(os.path.join(output_dir, filename))
 
 
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     is_batch = mode == 5
@@ -142,6 +142,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         inpainting_fill=inpainting_fill,
         resize_mode=resize_mode,
         denoising_strength=denoising_strength,
+        image_cfg_scale=image_cfg_scale,
         inpaint_full_res=inpaint_full_res,
         inpaint_full_res_padding=inpaint_full_res_padding,
         inpainting_mask_invert=inpainting_mask_invert,
modules/mac_specific.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import torch
+
+from modules import paths
+from modules.sd_hijack_utils import CondFunc
+from packaging import version
+
+
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def check_for_mps() -> bool:
+    if not getattr(torch, 'has_mps', False):
+        return False
+    try:
+        torch.zeros(1).to(torch.device("mps"))
+        return True
+    except Exception:
+        return False
+has_mps = check_for_mps()
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+    if input.device.type == 'mps':
+        output_dtype = kwargs.get('dtype', input.dtype)
+        if output_dtype == torch.int64:
+            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
+    return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps:
+    # MPS fix for randn in torchsde
+    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
+
+    if version.parse(torch.__version__) < version.parse("1.13"):
+        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
+            lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
+            lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
+    elif version.parse(torch.__version__) > version.parse("1.13.1"):
+        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
+        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
+        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
+        CondFunc('torch.cumsum', cumsum_fix_func, None)
+        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
+        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
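`CondFunc`, imported from `modules.sd_hijack_utils`, is the patching helper every workaround above goes through: it resolves a dotted path, replaces the target with a wrapper, and routes calls to the substitute only when the condition function returns True (a condition of `None` means always substitute; the substitute receives the original function as its first argument). A simplified sketch of the mechanics; the real helper is a class and differs in detail:

```python
import importlib

def cond_func_sketch(path: str, sub_func, cond_func):
    # Resolve a dotted path such as 'torch.Tensor.to': import the longest
    # importable module prefix, then walk the remaining attributes.
    parts = path.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            owner = importlib.import_module('.'.join(parts[:i]))
            break
        except ImportError:
            continue
    for name in parts[i:-1]:
        owner = getattr(owner, name)
    orig = getattr(owner, parts[-1])

    def wrapper(*args, **kwargs):
        # cond_func of None means "always substitute"; otherwise ask it first.
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(owner, parts[-1], wrapper)
```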
modules/processing.py
@@ -186,7 +186,7 @@ class StableDiffusionProcessing:
         return conditioning
 
     def edit_image_conditioning(self, source_image):
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
 
         return conditioning_image
 
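`encode_first_stage` returns the VAE posterior (LDM's `DiagonalGaussianDistribution`); the old call sampled from it and applied the model's scale factor via `get_first_stage_encoding`, while `.mode()` takes the distribution's mode, giving a deterministic, unscaled latent, which is the form InstructPix2Pix expects for its image conditioning. A minimal stand-in showing the sample/mode distinction (this class is a sketch, not the LDM implementation):

```python
import torch

class DiagonalGaussianSketch:
    """Toy stand-in for LDM's DiagonalGaussianDistribution."""
    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        self.mean, self.std = mean, torch.exp(0.5 * logvar)

    def sample(self) -> torch.Tensor:
        return self.mean + self.std * torch.randn_like(self.mean)  # stochastic

    def mode(self) -> torch.Tensor:
        return self.mean  # deterministic: the mode of a Gaussian is its mean

posterior = DiagonalGaussianSketch(torch.zeros(4), torch.zeros(4))
assert torch.equal(posterior.mode(), posterior.mode())  # repeatable conditioning
```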
@@ -268,6 +268,7 @@ class Processed:
         self.height = p.height
         self.sampler_name = p.sampler_name
         self.cfg_scale = p.cfg_scale
+        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
         self.steps = p.steps
         self.batch_size = p.batch_size
         self.restore_faces = p.restore_faces
@@ -445,6 +446,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Steps": p.steps,
         "Sampler": p.sampler_name,
         "CFG scale": p.cfg_scale,
+        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
         "Seed": all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
         "Size": f"{p.width}x{p.height}",
@@ -901,12 +903,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     sampler = None
 
-    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
         super().__init__(**kwargs)
 
         self.init_images = init_images
         self.resize_mode: int = resize_mode
         self.denoising_strength: float = denoising_strength
+        self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
         self.init_latent = None
         self.image_mask = mask
         self.latent_mask = None
modules/sd_models.py
@@ -59,13 +59,17 @@ class CheckpointInfo:
 
     def calculate_shorthash(self):
         self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+        if self.sha256 is None:
+            return
+
         self.shorthash = self.sha256[0:10]
 
         if self.shorthash not in self.ids:
-            self.ids += [self.shorthash, self.sha256]
-            self.register()
+            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
 
+        checkpoints_list.pop(self.title)
         self.title = f'{self.name} [{self.shorthash}]'
+        self.register()
 
         return self.shorthash
 
|
@ -158,7 +162,7 @@ def select_checkpoint():
|
||||||
print(f" - directory {model_path}", file=sys.stderr)
|
print(f" - directory {model_path}", file=sys.stderr)
|
||||||
if shared.cmd_opts.ckpt_dir is not None:
|
if shared.cmd_opts.ckpt_dir is not None:
|
||||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||||
print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||||
|
|
|
modules/sd_samplers_common.py
@@ -2,7 +2,6 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-import torchsde._brownian.brownian_interval
 from modules import devices, processing, images, sd_vae_approx
 
 from modules.shared import opts, state
@@ -61,18 +60,3 @@ def store_latent(decoded):
 
 class InterruptedException(BaseException):
     pass
-
-
-# MPS fix for randn in torchsde
-# XXX move this to separate file for MPS
-def torchsde_randn(size, dtype, device, seed):
-    if device.type == 'mps':
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
-    else:
-        generator = torch.Generator(device).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=device, generator=generator)
-
-
-torchsde._brownian.brownian_interval._randn = torchsde_randn
-
modules/sd_samplers_kdiffusion.py
@@ -1,6 +1,7 @@
 from collections import deque
 import torch
 import inspect
+import einops
 import k_diffusion.sampling
 from modules import prompt_parser, devices, sd_samplers_common
 
@@ -56,6 +57,7 @@ class CFGDenoiser(torch.nn.Module):
         self.nmask = None
         self.init_latent = None
         self.step = 0
+        self.image_cfg_scale = None
 
     def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
         denoised_uncond = x_out[-uncond.shape[0]:]
@@ -67,19 +69,36 @@ class CFGDenoiser(torch.nn.Module):
 
         return denoised
 
+    def combine_denoised_for_edit_model(self, x_out, cond_scale):
+        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+        return denoised
+
     def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException
 
+        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+        # so is_edit_model is set to False to support AND composition.
+        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
 
+        assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
 
-        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
-        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
-        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+        if not is_edit_model:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
+        else:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
+
         denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
         cfg_denoiser_callback(denoiser_params)
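`combine_denoised_for_edit_model` implements the two-scale classifier-free guidance from the InstructPix2Pix paper: with text scale s_T (`cond_scale`) and image scale s_I (`image_cfg_scale`), the three batched predictions (text+image conditioned, image-only, unconditional) are blended as `uncond + s_T * (cond - img_cond) + s_I * (img_cond - uncond)`. A self-contained check of the `image_cfg_scale == 1.0` special case that the comment above relies on (tensor values are arbitrary):

```python
import torch

def combine_for_edit(out_cond, out_img_cond, out_uncond, cond_scale, image_cfg_scale):
    # Two-scale guidance used for InstructPix2Pix-style edit models.
    return (out_uncond
            + cond_scale * (out_cond - out_img_cond)
            + image_cfg_scale * (out_img_cond - out_uncond))

out_cond, out_img_cond, out_uncond = torch.randn(3, 4)

# At image_cfg_scale == 1.0 the uncond terms cancel, leaving ordinary CFG
# taken against the image-conditioned branch; that is why is_edit_model is
# switched off in that case and normal sampling produces identical results.
edit = combine_for_edit(out_cond, out_img_cond, out_uncond, 7.0, 1.0)
plain = out_img_cond + 7.0 * (out_cond - out_img_cond)
assert torch.allclose(edit, plain)
```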
@@ -88,7 +107,10 @@ class CFGDenoiser(torch.nn.Module):
         sigma_in = denoiser_params.sigma
 
         if tensor.shape[1] == uncond.shape[1]:
-            cond_in = torch.cat([tensor, uncond])
+            if not is_edit_model:
+                cond_in = torch.cat([tensor, uncond])
+            else:
+                cond_in = torch.cat([tensor, uncond, uncond])
 
             if shared.batch_cond_uncond:
                 x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@@ -104,7 +126,13 @@ class CFGDenoiser(torch.nn.Module):
             for batch_offset in range(0, tensor.shape[0], batch_size):
                 a = batch_offset
                 b = min(a + batch_size, tensor.shape[0])
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
+
+                if not is_edit_model:
+                    c_crossattn = [tensor[a:b]]
+                else:
+                    c_crossattn = torch.cat([tensor[a:b]], uncond)
+
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
 
         x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
 
|
@ -115,7 +143,10 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
elif opts.live_preview_content == "Negative prompt":
|
elif opts.live_preview_content == "Negative prompt":
|
||||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||||
|
|
||||||
|
if not is_edit_model:
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
else:
|
||||||
|
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
@@ -198,6 +229,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap_cfg.step = 0
+        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
         self.eta = p.eta if p.eta is not None else opts.eta_ancestral
 
         k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -260,13 +292,14 @@ class KDiffusionSampler:
 
         self.model_wrap_cfg.init_latent = x
         self.last_latent = x
-        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+        extra_args={
             'cond': conditioning,
             'image_cond': image_conditioning,
             'uncond': unconditional_conditioning,
-            'cond_scale': p.cfg_scale
-        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+            'cond_scale': p.cfg_scale,
+        }
+
+        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
 
         return samples
 
modules/shared.py
@@ -106,7 +106,8 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
 parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
 parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
+parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
 
 
 script_loading.preload_extensions(extensions.extensions_dir, parser)
 
modules/ui.py
@@ -765,7 +765,9 @@ def create_ui():
 
                 elif category == "cfg":
                     with FormGroup():
-                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+                        with FormRow():
+                            cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+                            image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
                         denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
 
                 elif category == "seed":
@@ -861,6 +863,7 @@ def create_ui():
                     batch_count,
                     batch_size,
                     cfg_scale,
+                    image_cfg_scale,
                     denoising_strength,
                     seed,
                     subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
@@ -947,6 +950,7 @@ def create_ui():
             (sampler_index, "Sampler"),
             (restore_faces, "Face restoration"),
             (cfg_scale, "CFG scale"),
+            (image_cfg_scale, "Image CFG scale"),
             (seed, "Seed"),
             (width, "Size-1"),
             (height, "Size-2"),
@@ -1591,6 +1595,12 @@ def create_ui():
                 outputs=[component, text_settings],
             )
 
+        text_settings.change(
+            fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
+            inputs=[],
+            outputs=[image_cfg_scale],
+        )
+
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(
             fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
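The slider starts hidden unless an edit-model checkpoint is already loaded; this handler re-evaluates visibility whenever the settings JSON changes, which includes checkpoint swaps. The underlying gradio pattern, reduced to a minimal self-contained example (component names here are illustrative, not the webui's):

```python
import gradio as gr

with gr.Blocks() as demo:
    edit_model = gr.Checkbox(label="Edit model loaded")
    image_cfg = gr.Slider(0, 3, value=1.5, step=0.05, label="Image CFG Scale", visible=False)
    # An event handler may return gr.update(...) to reconfigure another
    # component; here the slider's visibility tracks the checkbox state.
    edit_model.change(fn=lambda v: gr.update(visible=v), inputs=[edit_model], outputs=[image_cfg])
```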
modules/ui_extra_networks.py
@@ -29,8 +29,9 @@ def add_pages_to_demo(app):
         if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
             raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
 
-        if os.path.splitext(filename)[1].lower() != ".png":
-            raise ValueError(f"File cannot be fetched: {filename}. Only png.")
+        ext = os.path.splitext(filename)[1].lower()
+        if ext not in (".png", ".jpg"):
+            raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
 
         # would profit from returning 304
         return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
scripts/img2imgalt.py
@@ -6,7 +6,7 @@ from tqdm import trange
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import processing, shared, sd_samplers, prompt_parser
+from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
 from modules.processing import Processed
 from modules.shared import opts, cmd_opts, state
 
@@ -50,7 +50,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
 
         x = x + d * dt
 
-        sd_samplers.store_latent(x)
+        sd_samplers_common.store_latent(x)
 
         # This shouldn't be necessary, but solved some VRAM issues
         del x_in, sigma_in, cond_in, c_out, c_in, t,
@@ -104,7 +104,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
         dt = sigmas[i] - sigmas[i - 1]
         x = x + d * dt
 
-        sd_samplers.store_latent(x)
+        sd_samplers_common.store_latent(x)
 
         # This shouldn't be necessary, but solved some VRAM issues
         del x_in, sigma_in, cond_in, c_out, c_in, t,