Merge remote-tracking branch 'upstream/master' into PowerShell
This commit is contained in:
commit
010787ad9c
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -32,3 +32,4 @@ notification.mp3
|
|||
/extensions
|
||||
/test/stdout.txt
|
||||
/test/stderr.txt
|
||||
/cache.json
|
||||
|
|
|
@ -92,6 +92,7 @@ titles = {
|
|||
"Weighted sum": "Result = A * (1 - M) + B * M",
|
||||
"Add difference": "Result = A + (B - C) * M",
|
||||
|
||||
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
||||
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||
|
||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
||||
|
|
|
@ -151,6 +151,7 @@ function showGalleryImage() {
|
|||
e.addEventListener('mousedown', function (evt) {
|
||||
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||
evt.preventDefault()
|
||||
showModal(evt)
|
||||
}, true);
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ function switch_to_img2img(){
|
|||
|
||||
function switch_to_inpaint(){
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click();
|
||||
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
@ -143,14 +143,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
|||
|
||||
|
||||
opts = {}
|
||||
function apply_settings(jsdata){
|
||||
console.log(jsdata)
|
||||
|
||||
opts = JSON.parse(jsdata)
|
||||
|
||||
return jsdata
|
||||
}
|
||||
|
||||
onUiUpdate(function(){
|
||||
if(Object.keys(opts).length != 0) return;
|
||||
|
||||
|
@ -160,7 +152,7 @@ onUiUpdate(function(){
|
|||
textarea = json_elem.querySelector('textarea')
|
||||
jsdata = textarea.value
|
||||
opts = JSON.parse(jsdata)
|
||||
|
||||
executeCallbacks(optionsChangedCallbacks);
|
||||
|
||||
Object.defineProperty(textarea, 'value', {
|
||||
set: function(newValue) {
|
||||
|
@ -171,6 +163,8 @@ onUiUpdate(function(){
|
|||
if (oldValue != newValue) {
|
||||
opts = JSON.parse(textarea.value)
|
||||
}
|
||||
|
||||
executeCallbacks(optionsChangedCallbacks);
|
||||
},
|
||||
get: function() {
|
||||
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
||||
|
@ -201,6 +195,19 @@ onUiUpdate(function(){
|
|||
}
|
||||
})
|
||||
|
||||
|
||||
onOptionsChanged(function(){
|
||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||
shorthash = sd_checkpoint_hash.substr(0,10)
|
||||
|
||||
if(elem && elem.textContent != shorthash){
|
||||
elem.textContent = shorthash
|
||||
elem.title = sd_checkpoint_hash
|
||||
elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
|
||||
}
|
||||
})
|
||||
|
||||
let txt2img_textarea, img2img_textarea = undefined;
|
||||
let wait_time = 800
|
||||
let token_timeout;
|
||||
|
|
|
@ -286,7 +286,7 @@ class Api:
|
|||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||
|
||||
# avoid dividing zero
|
||||
progress = 0.01
|
||||
|
@ -308,7 +308,7 @@ class Api:
|
|||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
|
@ -371,7 +371,7 @@ class Api:
|
|||
return upscalers
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
|
|
@ -168,6 +168,7 @@ class ProgressResponse(BaseModel):
|
|||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
@ -223,7 +224,8 @@ class UpscalerItem(BaseModel):
|
|||
class SDModelItem(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
hash: str = Field(title="Hash")
|
||||
hash: Optional[str] = Field(title="Short hash")
|
||||
sha256: Optional[str] = Field(title="sha256 hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import math
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -248,7 +249,32 @@ def run_pnginfo(image):
|
|||
return '', geninfo, info
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||
def create_config(ckpt_result, config_source, a, b, c):
|
||||
def config(x):
|
||||
return sd_models.find_checkpoint_config(x) if x else None
|
||||
|
||||
if config_source == 0:
|
||||
cfg = config(a) or config(b) or config(c)
|
||||
elif config_source == 1:
|
||||
cfg = config(b)
|
||||
elif config_source == 2:
|
||||
cfg = config(c)
|
||||
else:
|
||||
cfg = None
|
||||
|
||||
if cfg is None:
|
||||
return
|
||||
|
||||
filename, _ = os.path.splitext(ckpt_result)
|
||||
checkpoint_filename = filename + ".yaml"
|
||||
|
||||
print("Copying config:")
|
||||
print(" from:", cfg)
|
||||
print(" to:", checkpoint_filename)
|
||||
shutil.copyfile(cfg, checkpoint_filename)
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
|
||||
|
@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
|
||||
sd_models.list_models()
|
||||
|
||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||
|
||||
print("Checkpoint saved.")
|
||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
shared.state.end()
|
||||
|
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules import shared, ui_tempdir
|
||||
from modules import shared, ui_tempdir, script_callbacks
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
|
@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
|||
prompt = file.read()
|
||||
|
||||
params = parse_generation_parameters(prompt)
|
||||
script_callbacks.infotext_pasted_callback(prompt, params)
|
||||
res = []
|
||||
|
||||
for output, key in paste_fields:
|
||||
|
|
84
modules/hashes.py
Normal file
84
modules/hashes.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
import hashlib
|
||||
import json
|
||||
import os.path
|
||||
|
||||
import filelock
|
||||
|
||||
|
||||
cache_filename = "cache.json"
|
||||
cache_data = None
|
||||
|
||||
|
||||
def dump_cache():
|
||||
with filelock.FileLock(cache_filename+".lock"):
|
||||
with open(cache_filename, "w", encoding="utf8") as file:
|
||||
json.dump(cache_data, file, indent=4)
|
||||
|
||||
|
||||
def cache(subsection):
|
||||
global cache_data
|
||||
|
||||
if cache_data is None:
|
||||
with filelock.FileLock(cache_filename+".lock"):
|
||||
if not os.path.isfile(cache_filename):
|
||||
cache_data = {}
|
||||
else:
|
||||
with open(cache_filename, "r", encoding="utf8") as file:
|
||||
cache_data = json.load(file)
|
||||
|
||||
s = cache_data.get(subsection, {})
|
||||
cache_data[subsection] = s
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def calculate_sha256(filename):
|
||||
hash_sha256 = hashlib.sha256()
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
def sha256_from_cache(filename, title):
|
||||
hashes = cache("hashes")
|
||||
ondisk_mtime = os.path.getmtime(filename)
|
||||
|
||||
if title not in hashes:
|
||||
return None
|
||||
|
||||
cached_sha256 = hashes[title].get("sha256", None)
|
||||
cached_mtime = hashes[title].get("mtime", 0)
|
||||
|
||||
if ondisk_mtime > cached_mtime or cached_sha256 is None:
|
||||
return None
|
||||
|
||||
return cached_sha256
|
||||
|
||||
|
||||
def sha256(filename, title):
|
||||
hashes = cache("hashes")
|
||||
|
||||
sha256_value = sha256_from_cache(filename, title)
|
||||
if sha256_value is not None:
|
||||
return sha256_value
|
||||
|
||||
print(f"Calculating sha256 for {filename}: ", end='')
|
||||
sha256_value = calculate_sha256(filename)
|
||||
print(f"{sha256_value}")
|
||||
|
||||
hashes[title] = {
|
||||
"mtime": os.path.getmtime(filename),
|
||||
"sha256": sha256_value,
|
||||
}
|
||||
|
||||
dump_cache()
|
||||
|
||||
return sha256_value
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ import torch
|
|||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes
|
||||
from modules.textual_inversion import textual_inversion, logging
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
|
@ -24,7 +24,6 @@ from statistics import stdev, mean
|
|||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
|
@ -226,7 +225,7 @@ class Hypernetwork:
|
|||
|
||||
torch.save(state_dict, filename)
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
||||
optimizer_saved_dict['hash'] = self.shorthash()
|
||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
|
@ -238,32 +237,33 @@ class Hypernetwork:
|
|||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
print(self.layer_structure)
|
||||
optional_info = state_dict.get('optional_info', None)
|
||||
if optional_info is not None:
|
||||
print(f"INFO:\n {optional_info}\n")
|
||||
self.optional_info = optional_info
|
||||
self.optional_info = state_dict.get('optional_info', None)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
print(f"Activation function is {self.activation_func}")
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
print(f"Weight initialization is {self.weight_init}")
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||
self.dropout_structure = state_dict.get('dropout_structure', None)
|
||||
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
print(f"Activate last layer is set to {self.activate_output}")
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
||||
if self.dropout_structure is None:
|
||||
print("Using previous dropout structure")
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
print(f"Dropout structure is set to {self.dropout_structure}")
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
if shared.opts.print_hypernet_extra:
|
||||
if self.optional_info is not None:
|
||||
print(f" INFO:\n {self.optional_info}\n")
|
||||
|
||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
||||
print(f" Layer structure: {self.layer_structure}")
|
||||
print(f" Activation function: {self.activation_func}")
|
||||
print(f" Weight initialization: {self.weight_init}")
|
||||
print(f" Layer norm: {self.add_layer_norm}")
|
||||
print(f" Dropout usage: {self.use_dropout}" )
|
||||
print(f" Activate last layer: {self.activate_output}")
|
||||
print(f" Dropout structure: {self.dropout_structure}")
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
|
||||
if self.shorthash() == optimizer_saved_dict.get('hash', None):
|
||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
else:
|
||||
self.optimizer_state_dict = None
|
||||
|
@ -290,6 +290,11 @@ class Hypernetwork:
|
|||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||
self.eval()
|
||||
|
||||
def shorthash(self):
|
||||
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||
|
||||
return sha256[0:10]
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
res = {}
|
||||
|
@ -297,7 +302,7 @@ def list_hypernetworks(path):
|
|||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
||||
res[name] = filename
|
||||
return res
|
||||
|
||||
|
||||
|
@ -498,6 +503,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
if clip_grad:
|
||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
|
||||
|
@ -507,7 +515,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
|
||||
if shared.opts.save_training_settings_to_txt:
|
||||
saved_params = dict(
|
||||
model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
|
||||
model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
|
||||
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
||||
)
|
||||
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
||||
|
@ -619,7 +627,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
epoch_num = hypernetwork.step // steps_per_epoch
|
||||
epoch_step = hypernetwork.step % steps_per_epoch
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
||||
pbar.set_description(description)
|
||||
shared.state.textinfo = description
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
|
@ -630,6 +640,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
|
||||
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
epoch_num = hypernetwork.step // len(ds)
|
||||
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||
|
||||
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
|
@ -671,6 +689,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
|
@ -722,7 +743,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
|||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||
try:
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint = checkpoint.shorthash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
hypernetwork.name = hypernetwork_name
|
||||
hypernetwork.save(filename)
|
||||
|
|
|
@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
is_inpaint = mode == 1
|
||||
is_batch = mode == 2
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
is_batch = mode == 5
|
||||
|
||||
if is_inpaint:
|
||||
# Drawn mask
|
||||
if mask_mode == 0:
|
||||
is_mask_sketch = isinstance(init_img_with_mask, dict)
|
||||
is_mask_paint = not is_mask_sketch
|
||||
if is_mask_sketch:
|
||||
# Sketch: mask iff. not transparent
|
||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
else:
|
||||
# Color-sketch: mask iff. painted over
|
||||
image = init_img_with_mask
|
||||
orig = init_img_with_mask_orig or init_img_with_mask
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
|
||||
image = image.convert("RGB")
|
||||
# Uploaded mask
|
||||
else:
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
# No mask
|
||||
if mode == 0: # img2img
|
||||
image = init_img.convert("RGB")
|
||||
mask = None
|
||||
elif mode == 1: # img2img sketch
|
||||
image = sketch.convert("RGB")
|
||||
mask = None
|
||||
elif mode == 2: # inpaint
|
||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
image = image.convert("RGB")
|
||||
elif mode == 3: # inpaint sketch
|
||||
image = inpaint_color_sketch
|
||||
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
image = image.convert("RGB")
|
||||
elif mode == 4: # inpaint upload mask
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
else:
|
||||
image = init_img
|
||||
image = None
|
||||
mask = None
|
||||
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
|
|
|
@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
|
||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
|
@ -531,16 +531,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
def infotext(iteration=0, position_in_batch=0):
|
||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process(p)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
||||
|
|
|
@ -49,6 +49,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||
[[5, 'a c'], [10, 'a {b|d{ c']]
|
||||
>>> g("((a][:b:c [d:3]")
|
||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||
>>> g("[a|(b:1.1)]")
|
||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||
"""
|
||||
|
||||
def collect_steps(steps, tree):
|
||||
|
@ -84,7 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||
yield args[0].value
|
||||
def __default__(self, data, children, meta):
|
||||
for child in children:
|
||||
yield from child
|
||||
yield child
|
||||
return AtStep().transform(tree)
|
||||
|
||||
def get_schedule(prompt):
|
||||
|
|
|
@ -2,7 +2,7 @@ import sys
|
|||
import traceback
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
@ -71,6 +71,7 @@ callback_map = dict(
|
|||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
callbacks_infotext_pasted=[],
|
||||
callbacks_script_unloaded=[],
|
||||
)
|
||||
|
||||
|
@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams):
|
|||
report_exception(c, 'image_grid')
|
||||
|
||||
|
||||
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
|
||||
for c in callback_map['callbacks_infotext_pasted']:
|
||||
try:
|
||||
c.callback(infotext, params)
|
||||
except Exception:
|
||||
report_exception(c, 'infotext_pasted')
|
||||
|
||||
|
||||
def script_unloaded_callback():
|
||||
for c in reversed(callback_map['callbacks_script_unloaded']):
|
||||
try:
|
||||
|
@ -290,6 +299,15 @@ def on_image_grid(callback):
|
|||
add_callback(callback_map['callbacks_image_grid'], callback)
|
||||
|
||||
|
||||
def on_infotext_pasted(callback):
|
||||
"""register a function to be called before applying an infotext.
|
||||
The callback is called with two arguments:
|
||||
- infotext: str - raw infotext.
|
||||
- result: Dict[str, any] - parsed infotext parameters.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_infotext_pasted'], callback)
|
||||
|
||||
|
||||
def on_script_unloaded(callback):
|
||||
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
||||
the script did should be reverted here"""
|
||||
|
|
|
@ -152,7 +152,7 @@ def basedir():
|
|||
|
||||
scripts_data = []
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
|
@ -206,7 +206,7 @@ def load_scripts():
|
|||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
except Exception:
|
||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||
|
@ -241,7 +241,7 @@ class ScriptRunner:
|
|||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
|
||||
for script_class, path, basedir in scripts_data:
|
||||
for script_class, path, basedir, script_module in scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
|
|
|
@ -20,6 +20,19 @@ class DisableInitialization:
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.replaced = []
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
self.replaced.append((obj, field, original))
|
||||
setattr(obj, field, func)
|
||||
|
||||
return original
|
||||
|
||||
def __enter__(self):
|
||||
def do_nothing(*args, **kwargs):
|
||||
pass
|
||||
|
@ -37,11 +50,14 @@ class DisableInitialization:
|
|||
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||
|
||||
# this file is always 404, prevent making request
|
||||
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
|
||||
raise transformers.utils.hub.EntryNotFoundError
|
||||
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||
return None
|
||||
|
||||
try:
|
||||
return original(url, *args, local_files_only=True, **kwargs)
|
||||
res = original(url, *args, local_files_only=True, **kwargs)
|
||||
if res is None:
|
||||
res = original(url, *args, local_files_only=False, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
return original(url, *args, local_files_only=False, **kwargs)
|
||||
|
||||
|
@ -54,42 +70,19 @@ class DisableInitialization:
|
|||
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||
|
||||
self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
|
||||
self.init_no_grad_normal = torch.nn.init._no_grad_normal_
|
||||
self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
|
||||
self.create_model_and_transforms = open_clip.create_model_and_transforms
|
||||
self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
|
||||
self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None)
|
||||
self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None)
|
||||
self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None)
|
||||
self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None)
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = do_nothing
|
||||
torch.nn.init._no_grad_normal_ = do_nothing
|
||||
torch.nn.init._no_grad_uniform_ = do_nothing
|
||||
open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
|
||||
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
|
||||
if self.transformers_modeling_utils_load_pretrained_model is not None:
|
||||
transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model
|
||||
if self.transformers_tokenization_utils_base_cached_file is not None:
|
||||
transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file
|
||||
if self.transformers_configuration_utils_cached_file is not None:
|
||||
transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file
|
||||
if self.transformers_utils_hub_get_from_cache is not None:
|
||||
transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
|
||||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
|
||||
torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
|
||||
torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
|
||||
open_clip.create_model_and_transforms = self.create_model_and_transforms
|
||||
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
|
||||
if self.transformers_modeling_utils_load_pretrained_model is not None:
|
||||
transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model
|
||||
if self.transformers_tokenization_utils_base_cached_file is not None:
|
||||
transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file
|
||||
if self.transformers_configuration_utils_cached_file is not None:
|
||||
transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file
|
||||
if self.transformers_utils_hub_get_from_cache is not None:
|
||||
transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
|
||||
for obj, field, original in self.replaced:
|
||||
setattr(obj, field, original)
|
||||
|
||||
self.replaced.clear()
|
||||
|
||||
|
|
|
@ -14,17 +14,58 @@ import ldm.modules.midas as midas
|
|||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||
checkpoints_list = {}
|
||||
checkpoint_alisases = {}
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
class CheckpointInfo:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
abspath = os.path.abspath(filename)
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
name = os.path.basename(filename)
|
||||
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
self.title = name
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
self.hash = model_hash(filename)
|
||||
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
for id in self.ids:
|
||||
checkpoint_alisases[id] = self
|
||||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
|
||||
self.shorthash = self.sha256[0:10]
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256]
|
||||
self.register()
|
||||
|
||||
return self.shorthash
|
||||
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
|
@ -43,10 +84,14 @@ def setup_model():
|
|||
enable_midas_autodownload()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
convert = lambda name: int(name) if name.isdigit() else name.lower()
|
||||
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||
def checkpoint_tiles():
|
||||
def convert(name):
|
||||
return int(name) if name.isdigit() else name.lower()
|
||||
|
||||
def alphanumeric_key(key):
|
||||
return [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
|
||||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||
|
||||
|
||||
def find_checkpoint_config(info):
|
||||
|
@ -62,48 +107,38 @@ def find_checkpoint_config(info):
|
|||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
checkpoint_alisases.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
|
||||
|
||||
def modeltitle(path, shorthash):
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
name = os.path.basename(path)
|
||||
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
|
||||
return f'{name} [{shorthash}]', shortname
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if os.path.exists(cmd_ckpt):
|
||||
h = model_hash(cmd_ckpt)
|
||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||
shared.opts.data['sd_model_checkpoint'] = title
|
||||
checkpoint_info = CheckpointInfo(cmd_ckpt)
|
||||
checkpoint_info.register()
|
||||
|
||||
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
|
||||
for filename in model_list:
|
||||
h = model_hash(filename)
|
||||
title, short_model_name = modeltitle(filename, h)
|
||||
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||
checkpoint_info = CheckpointInfo(filename)
|
||||
checkpoint_info.register()
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(searchString):
|
||||
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
||||
if len(applicable) > 0:
|
||||
return applicable[0]
|
||||
def get_closet_checkpoint_match(search_string):
|
||||
checkpoint_info = checkpoint_alisases.get(search_string, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
"""old hash that only looks at a small part of the file and is prone to collisions"""
|
||||
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
|
@ -119,7 +154,7 @@ def model_hash(filename):
|
|||
def select_checkpoint():
|
||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||
|
||||
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
|
||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
|
@ -189,9 +224,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||
return sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
|
||||
|
@ -201,9 +235,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
|||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
else:
|
||||
# load from file
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
|
||||
sd = read_state_dict(checkpoint_file)
|
||||
sd = read_state_dict(checkpoint_info.filename)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
|
@ -235,15 +269,16 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
|||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_model_checkpoint = checkpoint_info.filename
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
sd_vae.load_vae(model, vae_file)
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||
sd_vae.load_vae(model, vae_file, vae_source)
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
|
@ -333,10 +368,15 @@ def load_model(checkpoint_info=None):
|
|||
|
||||
timer = Timer()
|
||||
|
||||
sd_model = None
|
||||
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if sd_model is None:
|
||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None):
|
|||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.current_image = sample_to_image(decoded)
|
||||
|
||||
|
@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler:
|
|||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler:
|
|||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
store_latent(x_out[0:uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
|
@ -423,7 +427,8 @@ class KDiffusionSampler:
|
|||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
store_latent(latent)
|
||||
if opts.live_preview_content == "Combined":
|
||||
store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
|
|
|
@ -9,23 +9,9 @@ import glob
|
|||
from copy import deepcopy
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
vae_dir = "VAE"
|
||||
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
|
||||
|
||||
|
||||
vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
|
||||
|
||||
default_vae_dict = {"auto": "auto", "None": None, None: None}
|
||||
default_vae_list = ["auto", "None"]
|
||||
|
||||
|
||||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
||||
vae_dict = dict(default_vae_dict)
|
||||
vae_list = list(default_vae_list)
|
||||
first_load = True
|
||||
vae_dict = {}
|
||||
|
||||
|
||||
base_vae = None
|
||||
|
@ -64,100 +50,69 @@ def restore_base_vae(model):
|
|||
|
||||
|
||||
def get_filename(filepath):
|
||||
return os.path.splitext(os.path.basename(filepath))[0]
|
||||
return os.path.basename(filepath)
|
||||
|
||||
|
||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||
global vae_dict, vae_list
|
||||
res = {}
|
||||
candidates = [
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
|
||||
def refresh_vae_list():
|
||||
vae_dict.clear()
|
||||
|
||||
paths = [
|
||||
os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
|
||||
os.path.join(sd_models.model_path, '**/*.vae.pt'),
|
||||
os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
|
||||
os.path.join(vae_path, '**/*.ckpt'),
|
||||
os.path.join(vae_path, '**/*.pt'),
|
||||
os.path.join(vae_path, '**/*.safetensors'),
|
||||
]
|
||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
||||
candidates.append(shared.cmd_opts.vae_path)
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
|
||||
paths += [
|
||||
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
|
||||
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
|
||||
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
|
||||
]
|
||||
|
||||
candidates = []
|
||||
for path in paths:
|
||||
candidates += glob.iglob(path, recursive=True)
|
||||
|
||||
for filepath in candidates:
|
||||
name = get_filename(filepath)
|
||||
res[name] = filepath
|
||||
vae_list.clear()
|
||||
vae_list.extend(default_vae_list)
|
||||
vae_list.extend(list(res.keys()))
|
||||
vae_dict.clear()
|
||||
vae_dict.update(res)
|
||||
vae_dict.update(default_vae_dict)
|
||||
return vae_list
|
||||
vae_dict[name] = filepath
|
||||
|
||||
|
||||
def get_vae_from_settings(vae_file="auto"):
|
||||
# else, we load from settings, if not set to be default
|
||||
if vae_file == "auto" and shared.opts.sd_vae is not None:
|
||||
# if saved VAE settings isn't recognized, fallback to auto
|
||||
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
|
||||
# if VAE selected but not found, fallback to auto
|
||||
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print(f"Selected VAE doesn't exist: {vae_file}")
|
||||
return vae_file
|
||||
def find_vae_near_checkpoint(checkpoint_file):
|
||||
checkpoint_path = os.path.splitext(checkpoint_file)[0]
|
||||
for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
|
||||
if os.path.isfile(vae_location):
|
||||
return vae_location
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
||||
global first_load, vae_dict, vae_list
|
||||
def resolve_vae(checkpoint_file):
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||
|
||||
# if vae_file argument is provided, it takes priority, but not saved
|
||||
if vae_file and vae_file not in default_vae_list:
|
||||
if not os.path.isfile(vae_file):
|
||||
print(f"VAE provided as function argument doesn't exist: {vae_file}")
|
||||
vae_file = "auto"
|
||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
||||
if first_load and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
||||
else:
|
||||
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
|
||||
# fallback to selector in settings, if vae selector not set to act as default fallback
|
||||
if not shared.opts.sd_vae_as_default:
|
||||
vae_file = get_vae_from_settings(vae_file)
|
||||
# vae-path cmd arg takes priority for auto
|
||||
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
print(f"Using VAE provided as command line argument: {vae_file}")
|
||||
# if still not found, try look for ".vae.pt" beside model
|
||||
model_path = os.path.splitext(checkpoint_file)[0]
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.pt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# if still not found, try look for ".vae.ckpt" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.ckpt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# if still not found, try look for ".vae.safetensors" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.safetensors"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# No more fallbacks for auto
|
||||
if vae_file == "auto":
|
||||
vae_file = None
|
||||
# Last check, just because
|
||||
if vae_file and not os.path.exists(vae_file):
|
||||
vae_file = None
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"):
|
||||
return vae_near_checkpoint, 'found near the checkpoint'
|
||||
|
||||
return vae_file
|
||||
if shared.opts.sd_vae == "None":
|
||||
return None, None
|
||||
|
||||
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||
if vae_from_options is not None:
|
||||
return vae_from_options, 'specified in settings'
|
||||
|
||||
if shared.opts.sd_vae != "Automatic":
|
||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None):
|
||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||
global vae_dict, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
@ -165,12 +120,12 @@ def load_vae(model, vae_file=None):
|
|||
if vae_file:
|
||||
if cache_enabled and vae_file in checkpoints_loaded:
|
||||
# use vae checkpoint cache
|
||||
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
||||
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
|
||||
store_base_vae(model)
|
||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
else:
|
||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||
store_base_vae(model)
|
||||
|
||||
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
|
||||
|
@ -191,14 +146,12 @@ def load_vae(model, vae_file=None):
|
|||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
vae_list.append(vae_opt)
|
||||
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
|
||||
loaded_vae_file = vae_file
|
||||
|
||||
first_load = False
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
|
@ -211,7 +164,10 @@ def clear_loaded_vae():
|
|||
loaded_vae_file = None
|
||||
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||
unspecified = object()
|
||||
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
|
||||
if not sd_model:
|
||||
|
@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
|
||||
checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
||||
if vae_file == unspecified:
|
||||
vae_file, vae_source = resolve_vae(checkpoint_file)
|
||||
else:
|
||||
vae_source = "from function argument"
|
||||
|
||||
if loaded_vae_file == vae_file:
|
||||
return
|
||||
|
@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file)
|
||||
load_vae(sd_model, vae_file, vae_source)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print("VAE Weights loaded.")
|
||||
print("VAE weights loaded.")
|
||||
return sd_model
|
||||
|
|
|
@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit
|
|||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
|
@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
|
|||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
|
@ -176,7 +176,7 @@ class State:
|
|||
self.interrupted = True
|
||||
|
||||
def nextjob(self):
|
||||
if opts.show_progress_every_n_steps == -1:
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
self.job_no += 1
|
||||
|
@ -224,7 +224,7 @@ class State:
|
|||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
|
@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
|
@ -373,13 +374,16 @@ options_templates.update(options_section(('training', "Training"), {
|
|||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
|
@ -419,13 +423,11 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
|
@ -440,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
|
@ -454,6 +463,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
|
||||
|
||||
options_templates.update()
|
||||
|
|
|
@ -3,8 +3,10 @@ import numpy as np
|
|||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
from torchvision import transforms
|
||||
from collections import defaultdict
|
||||
from random import shuffle, choices
|
||||
|
||||
import random
|
||||
import tqdm
|
||||
|
@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
|
|||
assert data_root, 'dataset directory not specified'
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.tag_drop_out = tag_drop_out
|
||||
groups = defaultdict(list)
|
||||
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
|
@ -103,18 +105,25 @@ class PersonalizedBase(Dataset):
|
|||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
with devices.autocast():
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
|
||||
groups[image.size].append(len(self.dataset))
|
||||
self.dataset.append(entry)
|
||||
del torchdata
|
||||
del latent_dist
|
||||
del latent_sample
|
||||
|
||||
self.length = len(self.dataset)
|
||||
self.groups = list(groups.values())
|
||||
assert self.length > 0, "No images have been found in the dataset."
|
||||
self.batch_size = min(batch_size, self.length)
|
||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
self.latent_sampling_method = latent_sampling_method
|
||||
|
||||
if len(groups) > 1:
|
||||
print("Buckets:")
|
||||
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
||||
print(f" {w}x{h}: {len(ids)}")
|
||||
print()
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
tags = filename_text.split(',')
|
||||
|
@ -137,9 +146,44 @@ class PersonalizedBase(Dataset):
|
|||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||
return entry
|
||||
|
||||
|
||||
class GroupedBatchSampler(Sampler):
|
||||
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||
super().__init__(data_source)
|
||||
|
||||
n = len(data_source)
|
||||
self.groups = data_source.groups
|
||||
self.len = n_batch = n // batch_size
|
||||
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
||||
self.base = [int(e) // batch_size for e in expected]
|
||||
self.n_rand_batches = nrb = n_batch - sum(self.base)
|
||||
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
def __iter__(self):
|
||||
b = self.batch_size
|
||||
|
||||
for g in self.groups:
|
||||
shuffle(g)
|
||||
|
||||
batches = []
|
||||
for g in self.groups:
|
||||
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||
for _ in range(self.n_rand_batches):
|
||||
rand_group = choices(self.groups, self.probs)[0]
|
||||
batches.append(choices(rand_group, k=b))
|
||||
|
||||
shuffle(batches)
|
||||
|
||||
yield from batches
|
||||
|
||||
|
||||
class PersonalizedDataLoader(DataLoader):
|
||||
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
|
||||
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||
if latent_sampling_method == "random":
|
||||
self.collate_fn = collate_wrapper_random
|
||||
else:
|
||||
|
|
|
@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
|
|||
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
|
||||
next_size = next_size + ((h*d)-(next_size % (h*d)))
|
||||
|
||||
data_np_low.resize(next_size)
|
||||
data_np_low = np.resize(data_np_low, next_size)
|
||||
data_np_low = data_np_low.reshape((h, -1, d))
|
||||
|
||||
data_np_high.resize(next_size)
|
||||
data_np_high = np.resize(data_np_high, next_size)
|
||||
data_np_high = data_np_high.reshape((h, -1, d))
|
||||
|
||||
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
|
||||
|
|
|
@ -2,7 +2,7 @@ import datetime
|
|||
import json
|
||||
import os
|
||||
|
||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
|
||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
|
||||
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
||||
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
||||
|
|
|
@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
params.process_caption_deepbooru = process_caption_deepbooru
|
||||
params.preprocess_txt_action = preprocess_txt_action
|
||||
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
pbar = tqdm.tqdm(files)
|
||||
for index, imagefile in enumerate(pbar):
|
||||
params.subindex = 0
|
||||
filename = os.path.join(src, imagefile)
|
||||
try:
|
||||
|
@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
except Exception:
|
||||
continue
|
||||
|
||||
description = f"Preprocessing [Image {index}/{len(files)}]"
|
||||
pbar.set_description(description)
|
||||
shared.state.textinfo = description
|
||||
|
||||
params.src = filename
|
||||
|
||||
existing_caption = None
|
||||
|
|
|
@ -9,8 +9,11 @@ import tqdm
|
|||
import html
|
||||
import datetime
|
||||
import csv
|
||||
import safetensors.torch
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||
import modules.textual_inversion.dataset
|
||||
|
@ -150,6 +153,8 @@ class EmbeddingDatabase:
|
|||
name = data.get('name', name)
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
data = safetensors.torch.load_file(path, device="cpu")
|
||||
else:
|
||||
return
|
||||
|
||||
|
@ -245,11 +250,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
|||
with devices.autocast():
|
||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||
|
||||
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
|
||||
#cond_model expects at least some text, so we provide '*' as backup.
|
||||
embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
|
||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||
|
||||
for i in range(num_vectors_per_token):
|
||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||
#Only copy if we provided an init_text, otherwise keep vectors as zeros
|
||||
if init_text:
|
||||
for i in range(num_vectors_per_token):
|
||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
|
@ -288,6 +296,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||
**values,
|
||||
})
|
||||
|
||||
def tensorboard_setup(log_directory):
|
||||
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
|
||||
return SummaryWriter(
|
||||
log_dir=os.path.join(log_directory, "tensorboard"),
|
||||
flush_secs=shared.opts.training_tensorboard_flush_every)
|
||||
|
||||
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
|
||||
tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
|
||||
tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
|
||||
tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
|
||||
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
|
||||
|
||||
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
|
||||
tensorboard_writer.add_scalar(tag=tag,
|
||||
scalar_value=value, global_step=step)
|
||||
|
||||
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
|
||||
# Convert a pil image to a torch tensor
|
||||
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
|
||||
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
|
||||
len(pil_image.getbands()))
|
||||
img_tensor = img_tensor.permute((2, 0, 1))
|
||||
|
||||
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
|
||||
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
assert model_name, f"{name} not selected"
|
||||
|
@ -366,13 +398,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = tensorboard_setup(log_directory)
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
|
||||
|
||||
if shared.opts.save_training_settings_to_txt:
|
||||
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
|
||||
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
|
@ -473,7 +508,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
epoch_num = embedding.step // steps_per_epoch
|
||||
epoch_step = embedding.step % steps_per_epoch
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
|
||||
pbar.set_description(description)
|
||||
shared.state.textinfo = description
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
|
@ -527,6 +564,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
|
@ -544,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_mid = '[{}]'.format(checkpoint.shorthash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
|
@ -586,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r
|
|||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
|
||||
try:
|
||||
embedding.sd_checkpoint = checkpoint.hash
|
||||
embedding.sd_checkpoint = checkpoint.shorthash
|
||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||
if remove_cached_checksum:
|
||||
embedding.cached_checksum = None
|
||||
|
|
114
modules/ui.py
114
modules/ui.py
|
@ -795,53 +795,67 @@ def create_ui():
|
|||
|
||||
with FormRow().style(equal_height=False):
|
||||
with gr.Column(variant='panel', elem_id="img2img_settings"):
|
||||
with gr.Tabs(elem_id="mode_img2img"):
|
||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
|
||||
|
||||
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
|
||||
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
|
||||
init_img_with_mask_orig = gr.State(None)
|
||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
|
||||
|
||||
use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
|
||||
if use_color_sketch:
|
||||
def update_orig(image, state):
|
||||
if image is not None:
|
||||
same_size = state is not None and state.size == image.size
|
||||
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
||||
edited = same_size and has_exact_match
|
||||
return image if not edited or state is None else state
|
||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
|
||||
inpaint_color_sketch_orig = gr.State(None)
|
||||
|
||||
init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
|
||||
def update_orig(image, state):
|
||||
if image is not None:
|
||||
same_size = state is not None and state.size == image.size
|
||||
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
||||
edited = same_size and has_exact_match
|
||||
return image if not edited or state is None else state
|
||||
|
||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
|
||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
|
||||
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
|
||||
|
||||
with FormRow():
|
||||
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
|
||||
mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
|
||||
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
|
||||
|
||||
with FormRow():
|
||||
mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
|
||||
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
|
||||
|
||||
with FormRow():
|
||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
|
||||
|
||||
with FormRow():
|
||||
with gr.Column():
|
||||
inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
|
||||
|
||||
with gr.Column(scale=4):
|
||||
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
||||
|
||||
with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
|
||||
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
|
||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||
|
||||
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
|
||||
with FormRow():
|
||||
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
|
||||
mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
|
||||
|
||||
with FormRow():
|
||||
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
|
||||
|
||||
with FormRow():
|
||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
|
||||
|
||||
with FormRow():
|
||||
with gr.Column():
|
||||
inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
|
||||
|
||||
with gr.Column(scale=4):
|
||||
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
||||
|
||||
def select_img2img_tab(tab):
|
||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||
|
||||
for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
|
||||
elem.select(
|
||||
fn=lambda tab=i: select_img2img_tab(tab),
|
||||
inputs=[],
|
||||
outputs=[inpaint_controls, mask_alpha],
|
||||
)
|
||||
|
||||
with FormRow():
|
||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||
|
||||
|
@ -900,20 +914,6 @@ def create_ui():
|
|||
]
|
||||
)
|
||||
|
||||
mask_mode.change(
|
||||
lambda mode, img: {
|
||||
init_img_with_mask: gr_show(mode == 0),
|
||||
init_img_inpaint: gr_show(mode == 1),
|
||||
init_mask_inpaint: gr_show(mode == 1),
|
||||
},
|
||||
inputs=[mask_mode, init_img_with_mask],
|
||||
outputs=[
|
||||
init_img_with_mask,
|
||||
init_img_inpaint,
|
||||
init_mask_inpaint,
|
||||
],
|
||||
)
|
||||
|
||||
img2img_args = dict(
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||
_js="submit_img2img",
|
||||
|
@ -924,11 +924,12 @@ def create_ui():
|
|||
img2img_prompt_style,
|
||||
img2img_prompt_style2,
|
||||
init_img,
|
||||
sketch,
|
||||
init_img_with_mask,
|
||||
init_img_with_mask_orig,
|
||||
inpaint_color_sketch,
|
||||
inpaint_color_sketch_orig,
|
||||
init_img_inpaint,
|
||||
init_mask_inpaint,
|
||||
mask_mode,
|
||||
steps,
|
||||
sampler_index,
|
||||
mask_blur,
|
||||
|
@ -1129,7 +1130,7 @@ def create_ui():
|
|||
with gr.Column(variant='panel'):
|
||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||
|
||||
with gr.Row():
|
||||
with FormRow():
|
||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
||||
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
||||
|
||||
|
@ -1143,11 +1144,13 @@ def create_ui():
|
|||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
||||
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
||||
|
||||
with gr.Row():
|
||||
with FormRow():
|
||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
||||
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||
|
@ -1703,6 +1706,7 @@ def create_ui():
|
|||
save_as_half,
|
||||
custom_name,
|
||||
checkpoint_format,
|
||||
config_source,
|
||||
],
|
||||
outputs=[
|
||||
submit_result,
|
||||
|
@ -1837,4 +1841,6 @@ xformers: {xformers_version}
|
|||
gradio: {gr.__version__}
|
||||
•
|
||||
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
|
||||
•
|
||||
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
|
||||
"""
|
||||
|
|
|
@ -52,7 +52,7 @@ def check_progress_call(id_part):
|
|||
image = gr.update(visible=False)
|
||||
preview_visibility = gr.update(visible=False)
|
||||
|
||||
if opts.show_progress_every_n_steps != 0:
|
||||
if opts.live_previews_enable:
|
||||
shared.state.set_current_image()
|
||||
image = shared.state.current_image
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
function gradioApp() {
|
||||
const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot
|
||||
const elems = document.getElementsByTagName('gradio-app')
|
||||
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
|
||||
return !!gradioShadowRoot ? gradioShadowRoot : document;
|
||||
}
|
||||
|
||||
|
@ -13,6 +14,7 @@ function get_uiCurrentTabContent() {
|
|||
|
||||
uiUpdateCallbacks = []
|
||||
uiTabChangeCallbacks = []
|
||||
optionsChangedCallbacks = []
|
||||
let uiCurrentTab = null
|
||||
|
||||
function onUiUpdate(callback){
|
||||
|
@ -21,6 +23,9 @@ function onUiUpdate(callback){
|
|||
function onUiTabChange(callback){
|
||||
uiTabChangeCallbacks.push(callback)
|
||||
}
|
||||
function onOptionsChanged(callback){
|
||||
optionsChangedCallbacks.push(callback)
|
||||
}
|
||||
|
||||
function runCallback(x, m){
|
||||
try {
|
||||
|
|
|
@ -146,11 +146,7 @@ class Script(scripts.Script):
|
|||
else:
|
||||
args = {"prompt": line}
|
||||
|
||||
n_iter = args.get("n_iter", 1)
|
||||
if n_iter != 1:
|
||||
job_count += n_iter
|
||||
else:
|
||||
job_count += 1
|
||||
job_count += args.get("n_iter", p.n_iter)
|
||||
|
||||
jobs.append(args)
|
||||
|
||||
|
|
|
@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs):
|
|||
|
||||
|
||||
def find_vae(name: str):
|
||||
if name.lower() in ['auto', 'none']:
|
||||
return name
|
||||
if name.lower() in ['auto', 'automatic']:
|
||||
return modules.sd_vae.unspecified
|
||||
if name.lower() == 'none':
|
||||
return None
|
||||
else:
|
||||
vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
|
||||
found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
|
||||
if found:
|
||||
return found[0]
|
||||
choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
|
||||
if len(choices) == 0:
|
||||
print(f"No VAE found for {name}; using automatic")
|
||||
return modules.sd_vae.unspecified
|
||||
else:
|
||||
return 'auto'
|
||||
return modules.sd_vae.vae_dict[choices[0]]
|
||||
|
||||
|
||||
def apply_vae(p, x, xs):
|
||||
if x.lower().strip() == 'none':
|
||||
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
|
||||
else:
|
||||
found = find_vae(x)
|
||||
if found:
|
||||
v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
|
||||
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
|
||||
|
||||
|
||||
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
|
||||
|
@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object):
|
|||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
modules.sd_models.reload_model_weights(self.model)
|
||||
modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
|
||||
|
||||
opts.data["sd_vae"] = self.vae
|
||||
modules.sd_vae.reload_vae_weights(self.model)
|
||||
|
||||
hypernetwork.load_hypernetwork(self.hypernetwork)
|
||||
hypernetwork.apply_strength()
|
||||
|
|
|
@ -557,7 +557,9 @@ canvas[key="mask"] {
|
|||
}
|
||||
|
||||
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
|
||||
img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img
|
||||
#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img,
|
||||
#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img,
|
||||
#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img
|
||||
{
|
||||
height: 480px !important;
|
||||
max-height: 480px !important;
|
||||
|
|
2
webui.py
2
webui.py
|
@ -78,6 +78,8 @@ def initialize():
|
|||
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
|
||||
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
|
|
Loading…
Reference in New Issue
Block a user