Merge remote-tracking branch 'upstream/master' into PowerShell
This commit is contained in:
commit
010787ad9c
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -32,3 +32,4 @@ notification.mp3
|
||||||
/extensions
|
/extensions
|
||||||
/test/stdout.txt
|
/test/stdout.txt
|
||||||
/test/stderr.txt
|
/test/stderr.txt
|
||||||
|
/cache.json
|
||||||
|
|
|
@ -92,6 +92,7 @@ titles = {
|
||||||
"Weighted sum": "Result = A * (1 - M) + B * M",
|
"Weighted sum": "Result = A * (1 - M) + B * M",
|
||||||
"Add difference": "Result = A + (B - C) * M",
|
"Add difference": "Result = A + (B - C) * M",
|
||||||
|
|
||||||
|
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
||||||
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||||
|
|
||||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
||||||
|
|
|
@ -151,6 +151,7 @@ function showGalleryImage() {
|
||||||
e.addEventListener('mousedown', function (evt) {
|
e.addEventListener('mousedown', function (evt) {
|
||||||
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||||
|
evt.preventDefault()
|
||||||
showModal(evt)
|
showModal(evt)
|
||||||
}, true);
|
}, true);
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ function switch_to_img2img(){
|
||||||
|
|
||||||
function switch_to_inpaint(){
|
function switch_to_inpaint(){
|
||||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
|
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click();
|
||||||
|
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
@ -143,14 +143,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
||||||
|
|
||||||
|
|
||||||
opts = {}
|
opts = {}
|
||||||
function apply_settings(jsdata){
|
|
||||||
console.log(jsdata)
|
|
||||||
|
|
||||||
opts = JSON.parse(jsdata)
|
|
||||||
|
|
||||||
return jsdata
|
|
||||||
}
|
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(Object.keys(opts).length != 0) return;
|
if(Object.keys(opts).length != 0) return;
|
||||||
|
|
||||||
|
@ -160,7 +152,7 @@ onUiUpdate(function(){
|
||||||
textarea = json_elem.querySelector('textarea')
|
textarea = json_elem.querySelector('textarea')
|
||||||
jsdata = textarea.value
|
jsdata = textarea.value
|
||||||
opts = JSON.parse(jsdata)
|
opts = JSON.parse(jsdata)
|
||||||
|
executeCallbacks(optionsChangedCallbacks);
|
||||||
|
|
||||||
Object.defineProperty(textarea, 'value', {
|
Object.defineProperty(textarea, 'value', {
|
||||||
set: function(newValue) {
|
set: function(newValue) {
|
||||||
|
@ -171,6 +163,8 @@ onUiUpdate(function(){
|
||||||
if (oldValue != newValue) {
|
if (oldValue != newValue) {
|
||||||
opts = JSON.parse(textarea.value)
|
opts = JSON.parse(textarea.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
executeCallbacks(optionsChangedCallbacks);
|
||||||
},
|
},
|
||||||
get: function() {
|
get: function() {
|
||||||
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
||||||
|
@ -201,6 +195,19 @@ onUiUpdate(function(){
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
onOptionsChanged(function(){
|
||||||
|
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||||
|
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||||
|
shorthash = sd_checkpoint_hash.substr(0,10)
|
||||||
|
|
||||||
|
if(elem && elem.textContent != shorthash){
|
||||||
|
elem.textContent = shorthash
|
||||||
|
elem.title = sd_checkpoint_hash
|
||||||
|
elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
let txt2img_textarea, img2img_textarea = undefined;
|
let txt2img_textarea, img2img_textarea = undefined;
|
||||||
let wait_time = 800
|
let wait_time = 800
|
||||||
let token_timeout;
|
let token_timeout;
|
||||||
|
|
|
@ -286,7 +286,7 @@ class Api:
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
|
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
# avoid dividing zero
|
# avoid dividing zero
|
||||||
progress = 0.01
|
progress = 0.01
|
||||||
|
@ -308,7 +308,7 @@ class Api:
|
||||||
if shared.state.current_image and not req.skip_current_image:
|
if shared.state.current_image and not req.skip_current_image:
|
||||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||||
|
|
||||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||||
image_b64 = interrogatereq.image
|
image_b64 = interrogatereq.image
|
||||||
|
@ -371,7 +371,7 @@ class Api:
|
||||||
return upscalers
|
return upscalers
|
||||||
|
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
|
|
@ -168,6 +168,7 @@ class ProgressResponse(BaseModel):
|
||||||
eta_relative: float = Field(title="ETA in secs")
|
eta_relative: float = Field(title="ETA in secs")
|
||||||
state: dict = Field(title="State", description="The current state snapshot")
|
state: dict = Field(title="State", description="The current state snapshot")
|
||||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||||
|
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||||
|
|
||||||
class InterrogateRequest(BaseModel):
|
class InterrogateRequest(BaseModel):
|
||||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||||
|
@ -223,7 +224,8 @@ class UpscalerItem(BaseModel):
|
||||||
class SDModelItem(BaseModel):
|
class SDModelItem(BaseModel):
|
||||||
title: str = Field(title="Title")
|
title: str = Field(title="Title")
|
||||||
model_name: str = Field(title="Model Name")
|
model_name: str = Field(title="Model Name")
|
||||||
hash: str = Field(title="Hash")
|
hash: Optional[str] = Field(title="Short hash")
|
||||||
|
sha256: Optional[str] = Field(title="sha256 hash")
|
||||||
filename: str = Field(title="Filename")
|
filename: str = Field(title="Filename")
|
||||||
config: str = Field(title="Config file")
|
config: str = Field(title="Config file")
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
import shutil
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -248,7 +249,32 @@ def run_pnginfo(image):
|
||||||
return '', geninfo, info
|
return '', geninfo, info
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
def create_config(ckpt_result, config_source, a, b, c):
|
||||||
|
def config(x):
|
||||||
|
return sd_models.find_checkpoint_config(x) if x else None
|
||||||
|
|
||||||
|
if config_source == 0:
|
||||||
|
cfg = config(a) or config(b) or config(c)
|
||||||
|
elif config_source == 1:
|
||||||
|
cfg = config(b)
|
||||||
|
elif config_source == 2:
|
||||||
|
cfg = config(c)
|
||||||
|
else:
|
||||||
|
cfg = None
|
||||||
|
|
||||||
|
if cfg is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
filename, _ = os.path.splitext(ckpt_result)
|
||||||
|
checkpoint_filename = filename + ".yaml"
|
||||||
|
|
||||||
|
print("Copying config:")
|
||||||
|
print(" from:", cfg)
|
||||||
|
print(" to:", checkpoint_filename)
|
||||||
|
shutil.copyfile(cfg, checkpoint_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
|
@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
|
|
||||||
sd_models.list_models()
|
sd_models.list_models()
|
||||||
|
|
||||||
|
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||||
|
|
||||||
print("Checkpoint saved.")
|
print("Checkpoint saved.")
|
||||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.shared import script_path
|
from modules.shared import script_path
|
||||||
from modules import shared, ui_tempdir
|
from modules import shared, ui_tempdir, script_callbacks
|
||||||
import tempfile
|
import tempfile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||||
prompt = file.read()
|
prompt = file.read()
|
||||||
|
|
||||||
params = parse_generation_parameters(prompt)
|
params = parse_generation_parameters(prompt)
|
||||||
|
script_callbacks.infotext_pasted_callback(prompt, params)
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for output, key in paste_fields:
|
for output, key in paste_fields:
|
||||||
|
|
84
modules/hashes.py
Normal file
84
modules/hashes.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import filelock
|
||||||
|
|
||||||
|
|
||||||
|
cache_filename = "cache.json"
|
||||||
|
cache_data = None
|
||||||
|
|
||||||
|
|
||||||
|
def dump_cache():
|
||||||
|
with filelock.FileLock(cache_filename+".lock"):
|
||||||
|
with open(cache_filename, "w", encoding="utf8") as file:
|
||||||
|
json.dump(cache_data, file, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def cache(subsection):
|
||||||
|
global cache_data
|
||||||
|
|
||||||
|
if cache_data is None:
|
||||||
|
with filelock.FileLock(cache_filename+".lock"):
|
||||||
|
if not os.path.isfile(cache_filename):
|
||||||
|
cache_data = {}
|
||||||
|
else:
|
||||||
|
with open(cache_filename, "r", encoding="utf8") as file:
|
||||||
|
cache_data = json.load(file)
|
||||||
|
|
||||||
|
s = cache_data.get(subsection, {})
|
||||||
|
cache_data[subsection] = s
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_sha256(filename):
|
||||||
|
hash_sha256 = hashlib.sha256()
|
||||||
|
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(4096), b""):
|
||||||
|
hash_sha256.update(chunk)
|
||||||
|
|
||||||
|
return hash_sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def sha256_from_cache(filename, title):
|
||||||
|
hashes = cache("hashes")
|
||||||
|
ondisk_mtime = os.path.getmtime(filename)
|
||||||
|
|
||||||
|
if title not in hashes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cached_sha256 = hashes[title].get("sha256", None)
|
||||||
|
cached_mtime = hashes[title].get("mtime", 0)
|
||||||
|
|
||||||
|
if ondisk_mtime > cached_mtime or cached_sha256 is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cached_sha256
|
||||||
|
|
||||||
|
|
||||||
|
def sha256(filename, title):
|
||||||
|
hashes = cache("hashes")
|
||||||
|
|
||||||
|
sha256_value = sha256_from_cache(filename, title)
|
||||||
|
if sha256_value is not None:
|
||||||
|
return sha256_value
|
||||||
|
|
||||||
|
print(f"Calculating sha256 for {filename}: ", end='')
|
||||||
|
sha256_value = calculate_sha256(filename)
|
||||||
|
print(f"{sha256_value}")
|
||||||
|
|
||||||
|
hashes[title] = {
|
||||||
|
"mtime": os.path.getmtime(filename),
|
||||||
|
"sha256": sha256_value,
|
||||||
|
}
|
||||||
|
|
||||||
|
dump_cache()
|
||||||
|
|
||||||
|
return sha256_value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from modules import devices, processing, sd_models, shared, sd_samplers
|
from modules import devices, processing, sd_models, shared, sd_samplers, hashes
|
||||||
from modules.textual_inversion import textual_inversion, logging
|
from modules.textual_inversion import textual_inversion, logging
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
@ -24,7 +24,6 @@ from statistics import stdev, mean
|
||||||
|
|
||||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
activation_dict = {
|
activation_dict = {
|
||||||
|
@ -226,7 +225,7 @@ class Hypernetwork:
|
||||||
|
|
||||||
torch.save(state_dict, filename)
|
torch.save(state_dict, filename)
|
||||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
optimizer_saved_dict['hash'] = self.shorthash()
|
||||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||||
|
|
||||||
|
@ -238,32 +237,33 @@ class Hypernetwork:
|
||||||
state_dict = torch.load(filename, map_location='cpu')
|
state_dict = torch.load(filename, map_location='cpu')
|
||||||
|
|
||||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||||
print(self.layer_structure)
|
self.optional_info = state_dict.get('optional_info', None)
|
||||||
optional_info = state_dict.get('optional_info', None)
|
|
||||||
if optional_info is not None:
|
|
||||||
print(f"INFO:\n {optional_info}\n")
|
|
||||||
self.optional_info = optional_info
|
|
||||||
self.activation_func = state_dict.get('activation_func', None)
|
self.activation_func = state_dict.get('activation_func', None)
|
||||||
print(f"Activation function is {self.activation_func}")
|
|
||||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||||
print(f"Weight initialization is {self.weight_init}")
|
|
||||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
|
||||||
self.dropout_structure = state_dict.get('dropout_structure', None)
|
self.dropout_structure = state_dict.get('dropout_structure', None)
|
||||||
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
||||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
|
||||||
self.activate_output = state_dict.get('activate_output', True)
|
self.activate_output = state_dict.get('activate_output', True)
|
||||||
print(f"Activate last layer is set to {self.activate_output}")
|
|
||||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||||
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
||||||
if self.dropout_structure is None:
|
if self.dropout_structure is None:
|
||||||
print("Using previous dropout structure")
|
|
||||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||||
print(f"Dropout structure is set to {self.dropout_structure}")
|
|
||||||
|
|
||||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
if shared.opts.print_hypernet_extra:
|
||||||
|
if self.optional_info is not None:
|
||||||
|
print(f" INFO:\n {self.optional_info}\n")
|
||||||
|
|
||||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
print(f" Layer structure: {self.layer_structure}")
|
||||||
|
print(f" Activation function: {self.activation_func}")
|
||||||
|
print(f" Weight initialization: {self.weight_init}")
|
||||||
|
print(f" Layer norm: {self.add_layer_norm}")
|
||||||
|
print(f" Dropout usage: {self.use_dropout}" )
|
||||||
|
print(f" Activate last layer: {self.activate_output}")
|
||||||
|
print(f" Dropout structure: {self.dropout_structure}")
|
||||||
|
|
||||||
|
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||||
|
|
||||||
|
if self.shorthash() == optimizer_saved_dict.get('hash', None):
|
||||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||||
else:
|
else:
|
||||||
self.optimizer_state_dict = None
|
self.optimizer_state_dict = None
|
||||||
|
@ -290,6 +290,11 @@ class Hypernetwork:
|
||||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
|
def shorthash(self):
|
||||||
|
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||||
|
|
||||||
|
return sha256[0:10]
|
||||||
|
|
||||||
|
|
||||||
def list_hypernetworks(path):
|
def list_hypernetworks(path):
|
||||||
res = {}
|
res = {}
|
||||||
|
@ -297,7 +302,7 @@ def list_hypernetworks(path):
|
||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
# Prevent a hypothetical "None.pt" from being listed.
|
# Prevent a hypothetical "None.pt" from being listed.
|
||||||
if name != "None":
|
if name != "None":
|
||||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
res[name] = filename
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -498,6 +503,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
||||||
|
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
|
||||||
|
@ -507,7 +515,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
|
|
||||||
if shared.opts.save_training_settings_to_txt:
|
if shared.opts.save_training_settings_to_txt:
|
||||||
saved_params = dict(
|
saved_params = dict(
|
||||||
model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
|
model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
|
||||||
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
||||||
)
|
)
|
||||||
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
||||||
|
@ -619,7 +627,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
epoch_num = hypernetwork.step // steps_per_epoch
|
epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
epoch_step = hypernetwork.step % steps_per_epoch
|
epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
|
|
||||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
||||||
|
pbar.set_description(description)
|
||||||
|
shared.state.textinfo = description
|
||||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||||
# Before saving, change name to match current checkpoint.
|
# Before saving, change name to match current checkpoint.
|
||||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||||
|
@ -630,6 +640,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
epoch_num = hypernetwork.step // len(ds)
|
||||||
|
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||||
|
|
||||||
|
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||||
"loss": f"{loss_step:.7f}",
|
"loss": f"{loss_step:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
"learn_rate": scheduler.learn_rate
|
||||||
|
@ -672,6 +690,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
image = processed.images[0] if len(processed.images) > 0 else None
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||||
|
textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
@ -722,7 +743,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
||||||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||||
try:
|
try:
|
||||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
hypernetwork.sd_checkpoint = checkpoint.shorthash
|
||||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||||
hypernetwork.name = hypernetwork_name
|
hypernetwork.name = hypernetwork_name
|
||||||
hypernetwork.save(filename)
|
hypernetwork.save(filename)
|
||||||
|
|
|
@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args):
|
||||||
processed_image.save(os.path.join(output_dir, filename))
|
processed_image.save(os.path.join(output_dir, filename))
|
||||||
|
|
||||||
|
|
||||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||||
is_inpaint = mode == 1
|
is_batch = mode == 5
|
||||||
is_batch = mode == 2
|
|
||||||
|
|
||||||
if is_inpaint:
|
if mode == 0: # img2img
|
||||||
# Drawn mask
|
image = init_img.convert("RGB")
|
||||||
if mask_mode == 0:
|
mask = None
|
||||||
is_mask_sketch = isinstance(init_img_with_mask, dict)
|
elif mode == 1: # img2img sketch
|
||||||
is_mask_paint = not is_mask_sketch
|
image = sketch.convert("RGB")
|
||||||
if is_mask_sketch:
|
mask = None
|
||||||
# Sketch: mask iff. not transparent
|
elif mode == 2: # inpaint
|
||||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||||
else:
|
image = image.convert("RGB")
|
||||||
# Color-sketch: mask iff. painted over
|
elif mode == 3: # inpaint sketch
|
||||||
image = init_img_with_mask
|
image = inpaint_color_sketch
|
||||||
orig = init_img_with_mask_orig or init_img_with_mask
|
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
||||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
# Uploaded mask
|
elif mode == 4: # inpaint upload mask
|
||||||
else:
|
|
||||||
image = init_img_inpaint
|
image = init_img_inpaint
|
||||||
mask = init_mask_inpaint
|
mask = init_mask_inpaint
|
||||||
# No mask
|
|
||||||
else:
|
else:
|
||||||
image = init_img
|
image = None
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
# Use the EXIF orientation of photos taken by smartphones.
|
# Use the EXIF orientation of photos taken by smartphones.
|
||||||
|
|
|
@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
|
"Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
|
||||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
|
@ -531,16 +531,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0):
|
||||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||||
|
|
||||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
|
||||||
processed = Processed(p, [], p.seed, "")
|
|
||||||
file.write(processed.infotext(p, 0))
|
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.process(p)
|
p.scripts.process(p)
|
||||||
|
|
||||||
|
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||||
|
processed = Processed(p, [], p.seed, "")
|
||||||
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||||
[[5, 'a c'], [10, 'a {b|d{ c']]
|
[[5, 'a c'], [10, 'a {b|d{ c']]
|
||||||
>>> g("((a][:b:c [d:3]")
|
>>> g("((a][:b:c [d:3]")
|
||||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||||
|
>>> g("[a|(b:1.1)]")
|
||||||
|
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
|
@ -84,7 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||||
yield args[0].value
|
yield args[0].value
|
||||||
def __default__(self, data, children, meta):
|
def __default__(self, data, children, meta):
|
||||||
for child in children:
|
for child in children:
|
||||||
yield from child
|
yield child
|
||||||
return AtStep().transform(tree)
|
return AtStep().transform(tree)
|
||||||
|
|
||||||
def get_schedule(prompt):
|
def get_schedule(prompt):
|
||||||
|
|
|
@ -2,7 +2,7 @@ import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Optional
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from gradio import Blocks
|
from gradio import Blocks
|
||||||
|
@ -71,6 +71,7 @@ callback_map = dict(
|
||||||
callbacks_before_component=[],
|
callbacks_before_component=[],
|
||||||
callbacks_after_component=[],
|
callbacks_after_component=[],
|
||||||
callbacks_image_grid=[],
|
callbacks_image_grid=[],
|
||||||
|
callbacks_infotext_pasted=[],
|
||||||
callbacks_script_unloaded=[],
|
callbacks_script_unloaded=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams):
|
||||||
report_exception(c, 'image_grid')
|
report_exception(c, 'image_grid')
|
||||||
|
|
||||||
|
|
||||||
|
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
|
||||||
|
for c in callback_map['callbacks_infotext_pasted']:
|
||||||
|
try:
|
||||||
|
c.callback(infotext, params)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'infotext_pasted')
|
||||||
|
|
||||||
|
|
||||||
def script_unloaded_callback():
|
def script_unloaded_callback():
|
||||||
for c in reversed(callback_map['callbacks_script_unloaded']):
|
for c in reversed(callback_map['callbacks_script_unloaded']):
|
||||||
try:
|
try:
|
||||||
|
@ -290,6 +299,15 @@ def on_image_grid(callback):
|
||||||
add_callback(callback_map['callbacks_image_grid'], callback)
|
add_callback(callback_map['callbacks_image_grid'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_infotext_pasted(callback):
|
||||||
|
"""register a function to be called before applying an infotext.
|
||||||
|
The callback is called with two arguments:
|
||||||
|
- infotext: str - raw infotext.
|
||||||
|
- result: Dict[str, any] - parsed infotext parameters.
|
||||||
|
"""
|
||||||
|
add_callback(callback_map['callbacks_infotext_pasted'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_script_unloaded(callback):
|
def on_script_unloaded(callback):
|
||||||
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
||||||
the script did should be reverted here"""
|
the script did should be reverted here"""
|
||||||
|
|
|
@ -152,7 +152,7 @@ def basedir():
|
||||||
|
|
||||||
scripts_data = []
|
scripts_data = []
|
||||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
|
||||||
def list_scripts(scriptdirname, extension):
|
def list_scripts(scriptdirname, extension):
|
||||||
|
@ -206,7 +206,7 @@ def load_scripts():
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for key, script_class in module.__dict__.items():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
|
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||||
|
@ -241,7 +241,7 @@ class ScriptRunner:
|
||||||
self.alwayson_scripts.clear()
|
self.alwayson_scripts.clear()
|
||||||
self.selectable_scripts.clear()
|
self.selectable_scripts.clear()
|
||||||
|
|
||||||
for script_class, path, basedir in scripts_data:
|
for script_class, path, basedir, script_module in scripts_data:
|
||||||
script = script_class()
|
script = script_class()
|
||||||
script.filename = path
|
script.filename = path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
|
|
|
@ -20,6 +20,19 @@ class DisableInitialization:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.replaced = []
|
||||||
|
|
||||||
|
def replace(self, obj, field, func):
|
||||||
|
original = getattr(obj, field, None)
|
||||||
|
if original is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.replaced.append((obj, field, original))
|
||||||
|
setattr(obj, field, func)
|
||||||
|
|
||||||
|
return original
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
def do_nothing(*args, **kwargs):
|
def do_nothing(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
@ -37,11 +50,14 @@ class DisableInitialization:
|
||||||
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||||
|
|
||||||
# this file is always 404, prevent making request
|
# this file is always 404, prevent making request
|
||||||
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
|
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||||
raise transformers.utils.hub.EntryNotFoundError
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return original(url, *args, local_files_only=True, **kwargs)
|
res = original(url, *args, local_files_only=True, **kwargs)
|
||||||
|
if res is None:
|
||||||
|
res = original(url, *args, local_files_only=False, **kwargs)
|
||||||
|
return res
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return original(url, *args, local_files_only=False, **kwargs)
|
return original(url, *args, local_files_only=False, **kwargs)
|
||||||
|
|
||||||
|
@ -54,42 +70,19 @@ class DisableInitialization:
|
||||||
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||||
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||||
|
|
||||||
self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
|
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||||
self.init_no_grad_normal = torch.nn.init._no_grad_normal_
|
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||||
self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
|
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||||
self.create_model_and_transforms = open_clip.create_model_and_transforms
|
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||||
self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
|
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||||
self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None)
|
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||||
self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None)
|
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||||
self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None)
|
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||||
self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None)
|
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||||
|
|
||||||
torch.nn.init.kaiming_uniform_ = do_nothing
|
|
||||||
torch.nn.init._no_grad_normal_ = do_nothing
|
|
||||||
torch.nn.init._no_grad_uniform_ = do_nothing
|
|
||||||
open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
|
|
||||||
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
|
|
||||||
if self.transformers_modeling_utils_load_pretrained_model is not None:
|
|
||||||
transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model
|
|
||||||
if self.transformers_tokenization_utils_base_cached_file is not None:
|
|
||||||
transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file
|
|
||||||
if self.transformers_configuration_utils_cached_file is not None:
|
|
||||||
transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file
|
|
||||||
if self.transformers_utils_hub_get_from_cache is not None:
|
|
||||||
transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
|
for obj, field, original in self.replaced:
|
||||||
torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
|
setattr(obj, field, original)
|
||||||
torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
|
|
||||||
open_clip.create_model_and_transforms = self.create_model_and_transforms
|
self.replaced.clear()
|
||||||
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
|
|
||||||
if self.transformers_modeling_utils_load_pretrained_model is not None:
|
|
||||||
transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model
|
|
||||||
if self.transformers_tokenization_utils_base_cached_file is not None:
|
|
||||||
transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file
|
|
||||||
if self.transformers_configuration_utils_cached_file is not None:
|
|
||||||
transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file
|
|
||||||
if self.transformers_utils_hub_get_from_cache is not None:
|
|
||||||
transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
|
|
||||||
|
|
||||||
|
|
|
@ -14,17 +14,58 @@ import ldm.modules.midas as midas
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
|
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
|
checkpoint_alisases = {}
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointInfo:
|
||||||
|
def __init__(self, filename):
|
||||||
|
self.filename = filename
|
||||||
|
abspath = os.path.abspath(filename)
|
||||||
|
|
||||||
|
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||||
|
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||||
|
elif abspath.startswith(model_path):
|
||||||
|
name = abspath.replace(model_path, '')
|
||||||
|
else:
|
||||||
|
name = os.path.basename(filename)
|
||||||
|
|
||||||
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
|
name = name[1:]
|
||||||
|
|
||||||
|
self.title = name
|
||||||
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
|
self.hash = model_hash(filename)
|
||||||
|
|
||||||
|
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
|
||||||
|
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||||
|
|
||||||
|
self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
|
||||||
|
|
||||||
|
def register(self):
|
||||||
|
checkpoints_list[self.title] = self
|
||||||
|
for id in self.ids:
|
||||||
|
checkpoint_alisases[id] = self
|
||||||
|
|
||||||
|
def calculate_shorthash(self):
|
||||||
|
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
|
||||||
|
self.shorthash = self.sha256[0:10]
|
||||||
|
|
||||||
|
if self.shorthash not in self.ids:
|
||||||
|
self.ids += [self.shorthash, self.sha256]
|
||||||
|
self.register()
|
||||||
|
|
||||||
|
return self.shorthash
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
|
||||||
|
@ -44,9 +85,13 @@ def setup_model():
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_tiles():
|
def checkpoint_tiles():
|
||||||
convert = lambda name: int(name) if name.isdigit() else name.lower()
|
def convert(name):
|
||||||
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
return int(name) if name.isdigit() else name.lower()
|
||||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
|
||||||
|
def alphanumeric_key(key):
|
||||||
|
return [convert(c) for c in re.split('([0-9]+)', key)]
|
||||||
|
|
||||||
|
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||||
|
|
||||||
|
|
||||||
def find_checkpoint_config(info):
|
def find_checkpoint_config(info):
|
||||||
|
@ -62,48 +107,38 @@ def find_checkpoint_config(info):
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
|
checkpoint_alisases.clear()
|
||||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
|
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
|
||||||
|
|
||||||
def modeltitle(path, shorthash):
|
|
||||||
abspath = os.path.abspath(path)
|
|
||||||
|
|
||||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
|
||||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
|
||||||
elif abspath.startswith(model_path):
|
|
||||||
name = abspath.replace(model_path, '')
|
|
||||||
else:
|
|
||||||
name = os.path.basename(path)
|
|
||||||
|
|
||||||
if name.startswith("\\") or name.startswith("/"):
|
|
||||||
name = name[1:]
|
|
||||||
|
|
||||||
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
|
||||||
|
|
||||||
return f'{name} [{shorthash}]', shortname
|
|
||||||
|
|
||||||
cmd_ckpt = shared.cmd_opts.ckpt
|
cmd_ckpt = shared.cmd_opts.ckpt
|
||||||
if os.path.exists(cmd_ckpt):
|
if os.path.exists(cmd_ckpt):
|
||||||
h = model_hash(cmd_ckpt)
|
checkpoint_info = CheckpointInfo(cmd_ckpt)
|
||||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
checkpoint_info.register()
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
|
||||||
shared.opts.data['sd_model_checkpoint'] = title
|
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
|
|
||||||
for filename in model_list:
|
for filename in model_list:
|
||||||
h = model_hash(filename)
|
checkpoint_info = CheckpointInfo(filename)
|
||||||
title, short_model_name = modeltitle(filename, h)
|
checkpoint_info.register()
|
||||||
|
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(searchString):
|
def get_closet_checkpoint_match(search_string):
|
||||||
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
checkpoint_info = checkpoint_alisases.get(search_string, None)
|
||||||
if len(applicable) > 0:
|
if checkpoint_info is not None:
|
||||||
return applicable[0]
|
return checkpoint_info
|
||||||
|
|
||||||
|
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
|
||||||
|
if found:
|
||||||
|
return found[0]
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def model_hash(filename):
|
def model_hash(filename):
|
||||||
|
"""old hash that only looks at a small part of the file and is prone to collisions"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
import hashlib
|
import hashlib
|
||||||
|
@ -119,7 +154,7 @@ def model_hash(filename):
|
||||||
def select_checkpoint():
|
def select_checkpoint():
|
||||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||||
|
|
||||||
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
|
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
@ -189,9 +224,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
||||||
checkpoint_file = checkpoint_info.filename
|
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
sd_model_hash = checkpoint_info.hash
|
|
||||||
|
|
||||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||||
|
|
||||||
|
@ -201,9 +235,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||||
else:
|
else:
|
||||||
# load from file
|
# load from file
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||||
|
|
||||||
sd = read_state_dict(checkpoint_file)
|
sd = read_state_dict(checkpoint_info.filename)
|
||||||
model.load_state_dict(sd, strict=False)
|
model.load_state_dict(sd, strict=False)
|
||||||
del sd
|
del sd
|
||||||
|
|
||||||
|
@ -235,15 +269,16 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||||
checkpoints_loaded.popitem(last=False) # LRU
|
checkpoints_loaded.popitem(last=False) # LRU
|
||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpoint = checkpoint_file
|
model.sd_model_checkpoint = checkpoint_info.filename
|
||||||
model.sd_checkpoint_info = checkpoint_info
|
model.sd_checkpoint_info = checkpoint_info
|
||||||
|
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||||
|
|
||||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||||
|
|
||||||
sd_vae.delete_base_vae()
|
sd_vae.delete_base_vae()
|
||||||
sd_vae.clear_loaded_vae()
|
sd_vae.clear_loaded_vae()
|
||||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||||
sd_vae.load_vae(model, vae_file)
|
sd_vae.load_vae(model, vae_file, vae_source)
|
||||||
|
|
||||||
|
|
||||||
def enable_midas_autodownload():
|
def enable_midas_autodownload():
|
||||||
|
@ -333,10 +368,15 @@ def load_model(checkpoint_info=None):
|
||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
|
sd_model = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization():
|
with sd_disable_initialization.DisableInitialization():
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if sd_model is None:
|
||||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
|
||||||
|
|
|
@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None):
|
||||||
def store_latent(decoded):
|
def store_latent(decoded):
|
||||||
state.current_latent = decoded
|
state.current_latent = decoded
|
||||||
|
|
||||||
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||||
if not shared.parallel_processing_allowed:
|
if not shared.parallel_processing_allowed:
|
||||||
shared.state.current_image = sample_to_image(decoded)
|
shared.state.current_image = sample_to_image(decoded)
|
||||||
|
|
||||||
|
@ -267,7 +267,6 @@ class VanillaStableDiffusionSampler:
|
||||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
|
|
||||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||||
|
|
||||||
|
if opts.live_preview_content == "Prompt":
|
||||||
|
store_latent(x_out[0:uncond.shape[0]])
|
||||||
|
elif opts.live_preview_content == "Negative prompt":
|
||||||
|
store_latent(x_out[-uncond.shape[0]:])
|
||||||
|
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
|
@ -423,6 +427,7 @@ class KDiffusionSampler:
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
step = d['i']
|
step = d['i']
|
||||||
latent = d["denoised"]
|
latent = d["denoised"]
|
||||||
|
if opts.live_preview_content == "Combined":
|
||||||
store_latent(latent)
|
store_latent(latent)
|
||||||
self.last_latent = latent
|
self.last_latent = latent
|
||||||
|
|
||||||
|
|
|
@ -9,23 +9,9 @@ import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
|
||||||
vae_dir = "VAE"
|
|
||||||
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
|
|
||||||
|
|
||||||
|
|
||||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||||
|
vae_dict = {}
|
||||||
|
|
||||||
default_vae_dict = {"auto": "auto", "None": None, None: None}
|
|
||||||
default_vae_list = ["auto", "None"]
|
|
||||||
|
|
||||||
|
|
||||||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
|
||||||
vae_dict = dict(default_vae_dict)
|
|
||||||
vae_list = list(default_vae_list)
|
|
||||||
first_load = True
|
|
||||||
|
|
||||||
|
|
||||||
base_vae = None
|
base_vae = None
|
||||||
|
@ -64,100 +50,69 @@ def restore_base_vae(model):
|
||||||
|
|
||||||
|
|
||||||
def get_filename(filepath):
|
def get_filename(filepath):
|
||||||
return os.path.splitext(os.path.basename(filepath))[0]
|
return os.path.basename(filepath)
|
||||||
|
|
||||||
|
|
||||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
def refresh_vae_list():
|
||||||
global vae_dict, vae_list
|
vae_dict.clear()
|
||||||
res = {}
|
|
||||||
candidates = [
|
paths = [
|
||||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
|
||||||
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
|
os.path.join(sd_models.model_path, '**/*.vae.pt'),
|
||||||
*glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
|
os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
|
||||||
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
|
os.path.join(vae_path, '**/*.ckpt'),
|
||||||
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
|
os.path.join(vae_path, '**/*.pt'),
|
||||||
*glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
|
os.path.join(vae_path, '**/*.safetensors'),
|
||||||
]
|
]
|
||||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
|
||||||
candidates.append(shared.cmd_opts.vae_path)
|
if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
|
||||||
|
paths += [
|
||||||
|
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
|
||||||
|
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
|
||||||
|
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
|
||||||
|
]
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
for path in paths:
|
||||||
|
candidates += glob.iglob(path, recursive=True)
|
||||||
|
|
||||||
for filepath in candidates:
|
for filepath in candidates:
|
||||||
name = get_filename(filepath)
|
name = get_filename(filepath)
|
||||||
res[name] = filepath
|
vae_dict[name] = filepath
|
||||||
vae_list.clear()
|
|
||||||
vae_list.extend(default_vae_list)
|
|
||||||
vae_list.extend(list(res.keys()))
|
|
||||||
vae_dict.clear()
|
|
||||||
vae_dict.update(res)
|
|
||||||
vae_dict.update(default_vae_dict)
|
|
||||||
return vae_list
|
|
||||||
|
|
||||||
|
|
||||||
def get_vae_from_settings(vae_file="auto"):
|
def find_vae_near_checkpoint(checkpoint_file):
|
||||||
# else, we load from settings, if not set to be default
|
checkpoint_path = os.path.splitext(checkpoint_file)[0]
|
||||||
if vae_file == "auto" and shared.opts.sd_vae is not None:
|
for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
|
||||||
# if saved VAE settings isn't recognized, fallback to auto
|
if os.path.isfile(vae_location):
|
||||||
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
|
return vae_location
|
||||||
# if VAE selected but not found, fallback to auto
|
|
||||||
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
|
return None
|
||||||
vae_file = "auto"
|
|
||||||
print(f"Selected VAE doesn't exist: {vae_file}")
|
|
||||||
return vae_file
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
def resolve_vae(checkpoint_file):
|
||||||
global first_load, vae_dict, vae_list
|
if shared.cmd_opts.vae_path is not None:
|
||||||
|
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||||
|
|
||||||
# if vae_file argument is provided, it takes priority, but not saved
|
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||||
if vae_file and vae_file not in default_vae_list:
|
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"):
|
||||||
if not os.path.isfile(vae_file):
|
return vae_near_checkpoint, 'found near the checkpoint'
|
||||||
print(f"VAE provided as function argument doesn't exist: {vae_file}")
|
|
||||||
vae_file = "auto"
|
|
||||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
|
||||||
if first_load and shared.cmd_opts.vae_path is not None:
|
|
||||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
|
||||||
vae_file = shared.cmd_opts.vae_path
|
|
||||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
|
||||||
else:
|
|
||||||
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
|
|
||||||
# fallback to selector in settings, if vae selector not set to act as default fallback
|
|
||||||
if not shared.opts.sd_vae_as_default:
|
|
||||||
vae_file = get_vae_from_settings(vae_file)
|
|
||||||
# vae-path cmd arg takes priority for auto
|
|
||||||
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
|
|
||||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
|
||||||
vae_file = shared.cmd_opts.vae_path
|
|
||||||
print(f"Using VAE provided as command line argument: {vae_file}")
|
|
||||||
# if still not found, try look for ".vae.pt" beside model
|
|
||||||
model_path = os.path.splitext(checkpoint_file)[0]
|
|
||||||
if vae_file == "auto":
|
|
||||||
vae_file_try = model_path + ".vae.pt"
|
|
||||||
if os.path.isfile(vae_file_try):
|
|
||||||
vae_file = vae_file_try
|
|
||||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
|
||||||
# if still not found, try look for ".vae.ckpt" beside model
|
|
||||||
if vae_file == "auto":
|
|
||||||
vae_file_try = model_path + ".vae.ckpt"
|
|
||||||
if os.path.isfile(vae_file_try):
|
|
||||||
vae_file = vae_file_try
|
|
||||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
|
||||||
# if still not found, try look for ".vae.safetensors" beside model
|
|
||||||
if vae_file == "auto":
|
|
||||||
vae_file_try = model_path + ".vae.safetensors"
|
|
||||||
if os.path.isfile(vae_file_try):
|
|
||||||
vae_file = vae_file_try
|
|
||||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
|
||||||
# No more fallbacks for auto
|
|
||||||
if vae_file == "auto":
|
|
||||||
vae_file = None
|
|
||||||
# Last check, just because
|
|
||||||
if vae_file and not os.path.exists(vae_file):
|
|
||||||
vae_file = None
|
|
||||||
|
|
||||||
return vae_file
|
if shared.opts.sd_vae == "None":
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||||
|
if vae_from_options is not None:
|
||||||
|
return vae_from_options, 'specified in settings'
|
||||||
|
|
||||||
|
if shared.opts.sd_vae != "Automatic":
|
||||||
|
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def load_vae(model, vae_file=None):
|
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
global vae_dict, loaded_vae_file
|
||||||
# save_settings = False
|
# save_settings = False
|
||||||
|
|
||||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||||
|
@ -165,12 +120,12 @@ def load_vae(model, vae_file=None):
|
||||||
if vae_file:
|
if vae_file:
|
||||||
if cache_enabled and vae_file in checkpoints_loaded:
|
if cache_enabled and vae_file in checkpoints_loaded:
|
||||||
# use vae checkpoint cache
|
# use vae checkpoint cache
|
||||||
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
|
||||||
store_base_vae(model)
|
store_base_vae(model)
|
||||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||||
else:
|
else:
|
||||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
|
||||||
print(f"Loading VAE weights from: {vae_file}")
|
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||||
store_base_vae(model)
|
store_base_vae(model)
|
||||||
|
|
||||||
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
|
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
|
||||||
|
@ -191,14 +146,12 @@ def load_vae(model, vae_file=None):
|
||||||
vae_opt = get_filename(vae_file)
|
vae_opt = get_filename(vae_file)
|
||||||
if vae_opt not in vae_dict:
|
if vae_opt not in vae_dict:
|
||||||
vae_dict[vae_opt] = vae_file
|
vae_dict[vae_opt] = vae_file
|
||||||
vae_list.append(vae_opt)
|
|
||||||
elif loaded_vae_file:
|
elif loaded_vae_file:
|
||||||
restore_base_vae(model)
|
restore_base_vae(model)
|
||||||
|
|
||||||
loaded_vae_file = vae_file
|
loaded_vae_file = vae_file
|
||||||
|
|
||||||
first_load = False
|
|
||||||
|
|
||||||
|
|
||||||
# don't call this from outside
|
# don't call this from outside
|
||||||
def _load_vae_dict(model, vae_dict_1):
|
def _load_vae_dict(model, vae_dict_1):
|
||||||
|
@ -211,7 +164,10 @@ def clear_loaded_vae():
|
||||||
loaded_vae_file = None
|
loaded_vae_file = None
|
||||||
|
|
||||||
|
|
||||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
unspecified = object()
|
||||||
|
|
||||||
|
|
||||||
|
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import lowvram, devices, sd_hijack
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
|
@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||||
|
|
||||||
checkpoint_info = sd_model.sd_checkpoint_info
|
checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
|
||||||
|
if vae_file == unspecified:
|
||||||
|
vae_file, vae_source = resolve_vae(checkpoint_file)
|
||||||
|
else:
|
||||||
|
vae_source = "from function argument"
|
||||||
|
|
||||||
if loaded_vae_file == vae_file:
|
if loaded_vae_file == vae_file:
|
||||||
return
|
return
|
||||||
|
@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
load_vae(sd_model, vae_file)
|
load_vae(sd_model, vae_file, vae_source)
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
|
|
||||||
print("VAE Weights loaded.")
|
print("VAE weights loaded.")
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
|
@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit
|
||||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||||
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
|
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||||
parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
|
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
|
@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
|
||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||||
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||||
|
@ -176,7 +176,7 @@ class State:
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
|
|
||||||
def nextjob(self):
|
def nextjob(self):
|
||||||
if opts.show_progress_every_n_steps == -1:
|
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
||||||
self.do_set_current_image()
|
self.do_set_current_image()
|
||||||
|
|
||||||
self.job_no += 1
|
self.job_no += 1
|
||||||
|
@ -224,7 +224,7 @@ class State:
|
||||||
if not parallel_processing_allowed:
|
if not parallel_processing_allowed:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable:
|
||||||
self.do_set_current_image()
|
self.do_set_current_image()
|
||||||
|
|
||||||
def do_set_current_image(self):
|
def do_set_current_image(self):
|
||||||
|
@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
|
@ -373,13 +374,16 @@ options_templates.update(options_section(('training', "Training"), {
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||||
|
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||||
|
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||||
|
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
||||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||||
|
@ -419,13 +423,11 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
|
||||||
"show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
|
||||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
|
@ -440,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('ui', "Live previews"), {
|
||||||
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
|
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||||
|
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||||
|
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||||
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
||||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
@ -454,6 +463,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||||
|
|
||||||
options_templates.update(options_section((None, "Hidden options"), {
|
options_templates.update(options_section((None, "Hidden options"), {
|
||||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||||
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update()
|
options_templates.update()
|
||||||
|
|
|
@ -3,8 +3,10 @@ import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
from collections import defaultdict
|
||||||
|
from random import shuffle, choices
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import tqdm
|
import tqdm
|
||||||
|
@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
|
||||||
assert data_root, 'dataset directory not specified'
|
assert data_root, 'dataset directory not specified'
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
|
|
||||||
|
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
|
||||||
self.shuffle_tags = shuffle_tags
|
self.shuffle_tags = shuffle_tags
|
||||||
self.tag_drop_out = tag_drop_out
|
self.tag_drop_out = tag_drop_out
|
||||||
|
groups = defaultdict(list)
|
||||||
|
|
||||||
print("Preparing dataset...")
|
print("Preparing dataset...")
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
@ -103,18 +105,25 @@ class PersonalizedBase(Dataset):
|
||||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
|
groups[image.size].append(len(self.dataset))
|
||||||
self.dataset.append(entry)
|
self.dataset.append(entry)
|
||||||
del torchdata
|
del torchdata
|
||||||
del latent_dist
|
del latent_dist
|
||||||
del latent_sample
|
del latent_sample
|
||||||
|
|
||||||
self.length = len(self.dataset)
|
self.length = len(self.dataset)
|
||||||
|
self.groups = list(groups.values())
|
||||||
assert self.length > 0, "No images have been found in the dataset."
|
assert self.length > 0, "No images have been found in the dataset."
|
||||||
self.batch_size = min(batch_size, self.length)
|
self.batch_size = min(batch_size, self.length)
|
||||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||||
self.latent_sampling_method = latent_sampling_method
|
self.latent_sampling_method = latent_sampling_method
|
||||||
|
|
||||||
|
if len(groups) > 1:
|
||||||
|
print("Buckets:")
|
||||||
|
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
||||||
|
print(f" {w}x{h}: {len(ids)}")
|
||||||
|
print()
|
||||||
|
|
||||||
def create_text(self, filename_text):
|
def create_text(self, filename_text):
|
||||||
text = random.choice(self.lines)
|
text = random.choice(self.lines)
|
||||||
tags = filename_text.split(',')
|
tags = filename_text.split(',')
|
||||||
|
@ -137,9 +146,44 @@ class PersonalizedBase(Dataset):
|
||||||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
class GroupedBatchSampler(Sampler):
|
||||||
|
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||||
|
super().__init__(data_source)
|
||||||
|
|
||||||
|
n = len(data_source)
|
||||||
|
self.groups = data_source.groups
|
||||||
|
self.len = n_batch = n // batch_size
|
||||||
|
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
||||||
|
self.base = [int(e) // batch_size for e in expected]
|
||||||
|
self.n_rand_batches = nrb = n_batch - sum(self.base)
|
||||||
|
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.len
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
b = self.batch_size
|
||||||
|
|
||||||
|
for g in self.groups:
|
||||||
|
shuffle(g)
|
||||||
|
|
||||||
|
batches = []
|
||||||
|
for g in self.groups:
|
||||||
|
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||||
|
for _ in range(self.n_rand_batches):
|
||||||
|
rand_group = choices(self.groups, self.probs)[0]
|
||||||
|
batches.append(choices(rand_group, k=b))
|
||||||
|
|
||||||
|
shuffle(batches)
|
||||||
|
|
||||||
|
yield from batches
|
||||||
|
|
||||||
|
|
||||||
class PersonalizedDataLoader(DataLoader):
|
class PersonalizedDataLoader(DataLoader):
|
||||||
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||||
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
|
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||||
if latent_sampling_method == "random":
|
if latent_sampling_method == "random":
|
||||||
self.collate_fn = collate_wrapper_random
|
self.collate_fn = collate_wrapper_random
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
|
||||||
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
|
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
|
||||||
next_size = next_size + ((h*d)-(next_size % (h*d)))
|
next_size = next_size + ((h*d)-(next_size % (h*d)))
|
||||||
|
|
||||||
data_np_low.resize(next_size)
|
data_np_low = np.resize(data_np_low, next_size)
|
||||||
data_np_low = data_np_low.reshape((h, -1, d))
|
data_np_low = data_np_low.reshape((h, -1, d))
|
||||||
|
|
||||||
data_np_high.resize(next_size)
|
data_np_high = np.resize(data_np_high, next_size)
|
||||||
data_np_high = data_np_high.reshape((h, -1, d))
|
data_np_high = data_np_high.reshape((h, -1, d))
|
||||||
|
|
||||||
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
|
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
|
||||||
|
|
|
@ -2,7 +2,7 @@ import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
|
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
|
||||||
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
||||||
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
||||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
||||||
|
|
|
@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||||
params.process_caption_deepbooru = process_caption_deepbooru
|
params.process_caption_deepbooru = process_caption_deepbooru
|
||||||
params.preprocess_txt_action = preprocess_txt_action
|
params.preprocess_txt_action = preprocess_txt_action
|
||||||
|
|
||||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
pbar = tqdm.tqdm(files)
|
||||||
|
for index, imagefile in enumerate(pbar):
|
||||||
params.subindex = 0
|
params.subindex = 0
|
||||||
filename = os.path.join(src, imagefile)
|
filename = os.path.join(src, imagefile)
|
||||||
try:
|
try:
|
||||||
|
@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
description = f"Preprocessing [Image {index}/{len(files)}]"
|
||||||
|
pbar.set_description(description)
|
||||||
|
shared.state.textinfo = description
|
||||||
|
|
||||||
params.src = filename
|
params.src = filename
|
||||||
|
|
||||||
existing_caption = None
|
existing_caption = None
|
||||||
|
|
|
@ -9,8 +9,11 @@ import tqdm
|
||||||
import html
|
import html
|
||||||
import datetime
|
import datetime
|
||||||
import csv
|
import csv
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
@ -150,6 +153,8 @@ class EmbeddingDatabase:
|
||||||
name = data.get('name', name)
|
name = data.get('name', name)
|
||||||
elif ext in ['.BIN', '.PT']:
|
elif ext in ['.BIN', '.PT']:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
elif ext in ['.SAFETENSORS']:
|
||||||
|
data = safetensors.torch.load_file(path, device="cpu")
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -245,9 +250,12 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||||
|
|
||||||
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
|
#cond_model expects at least some text, so we provide '*' as backup.
|
||||||
|
embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
|
||||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||||
|
|
||||||
|
#Only copy if we provided an init_text, otherwise keep vectors as zeros
|
||||||
|
if init_text:
|
||||||
for i in range(num_vectors_per_token):
|
for i in range(num_vectors_per_token):
|
||||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
|
@ -288,6 +296,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
**values,
|
**values,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def tensorboard_setup(log_directory):
|
||||||
|
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
|
||||||
|
return SummaryWriter(
|
||||||
|
log_dir=os.path.join(log_directory, "tensorboard"),
|
||||||
|
flush_secs=shared.opts.training_tensorboard_flush_every)
|
||||||
|
|
||||||
|
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
|
||||||
|
tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
|
||||||
|
tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
|
||||||
|
tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
|
||||||
|
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
|
||||||
|
|
||||||
|
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
|
||||||
|
tensorboard_writer.add_scalar(tag=tag,
|
||||||
|
scalar_value=value, global_step=step)
|
||||||
|
|
||||||
|
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
|
||||||
|
# Convert a pil image to a torch tensor
|
||||||
|
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
|
||||||
|
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
|
||||||
|
len(pil_image.getbands()))
|
||||||
|
img_tensor = img_tensor.permute((2, 0, 1))
|
||||||
|
|
||||||
|
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
|
||||||
|
|
||||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||||
assert model_name, f"{name} not selected"
|
assert model_name, f"{name} not selected"
|
||||||
|
@ -367,12 +399,15 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
tensorboard_writer = tensorboard_setup(log_directory)
|
||||||
|
|
||||||
pin_memory = shared.opts.pin_memory
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
|
||||||
|
|
||||||
if shared.opts.save_training_settings_to_txt:
|
if shared.opts.save_training_settings_to_txt:
|
||||||
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
|
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
|
||||||
|
|
||||||
latent_sampling_method = ds.latent_sampling_method
|
latent_sampling_method = ds.latent_sampling_method
|
||||||
|
|
||||||
|
@ -473,7 +508,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
epoch_num = embedding.step // steps_per_epoch
|
epoch_num = embedding.step // steps_per_epoch
|
||||||
epoch_step = embedding.step % steps_per_epoch
|
epoch_step = embedding.step % steps_per_epoch
|
||||||
|
|
||||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
|
||||||
|
pbar.set_description(description)
|
||||||
|
shared.state.textinfo = description
|
||||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||||
# Before saving, change name to match current checkpoint.
|
# Before saving, change name to match current checkpoint.
|
||||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||||
|
@ -527,6 +564,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||||
|
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||||
|
|
||||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||||
|
@ -544,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
|
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
footer_left = checkpoint.model_name
|
footer_left = checkpoint.model_name
|
||||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
footer_mid = '[{}]'.format(checkpoint.shorthash)
|
||||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||||
|
|
||||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||||
|
@ -586,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r
|
||||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||||
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
|
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
|
||||||
try:
|
try:
|
||||||
embedding.sd_checkpoint = checkpoint.hash
|
embedding.sd_checkpoint = checkpoint.shorthash
|
||||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||||
if remove_cached_checksum:
|
if remove_cached_checksum:
|
||||||
embedding.cached_checksum = None
|
embedding.cached_checksum = None
|
||||||
|
|
|
@ -795,17 +795,20 @@ def create_ui():
|
||||||
|
|
||||||
with FormRow().style(equal_height=False):
|
with FormRow().style(equal_height=False):
|
||||||
with gr.Column(variant='panel', elem_id="img2img_settings"):
|
with gr.Column(variant='panel', elem_id="img2img_settings"):
|
||||||
|
with gr.Tabs(elem_id="mode_img2img"):
|
||||||
|
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||||
|
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
|
||||||
|
|
||||||
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
|
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
|
||||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
|
|
||||||
|
|
||||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
|
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
|
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
|
||||||
init_img_with_mask_orig = gr.State(None)
|
|
||||||
|
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||||
|
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
|
||||||
|
inpaint_color_sketch_orig = gr.State(None)
|
||||||
|
|
||||||
use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
|
|
||||||
if use_color_sketch:
|
|
||||||
def update_orig(image, state):
|
def update_orig(image, state):
|
||||||
if image is not None:
|
if image is not None:
|
||||||
same_size = state is not None and state.size == image.size
|
same_size = state is not None and state.size == image.size
|
||||||
|
@ -813,17 +816,24 @@ def create_ui():
|
||||||
edited = same_size and has_exact_match
|
edited = same_size and has_exact_match
|
||||||
return image if not edited or state is None else state
|
return image if not edited or state is None else state
|
||||||
|
|
||||||
init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
|
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
|
||||||
|
|
||||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
|
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
||||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
|
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
||||||
|
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
|
||||||
|
|
||||||
|
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||||
|
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||||
|
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
|
||||||
|
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||||
|
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||||
|
|
||||||
|
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
|
||||||
with FormRow():
|
with FormRow():
|
||||||
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
|
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
|
||||||
mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
|
mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
|
|
||||||
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
|
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
|
@ -836,11 +846,15 @@ def create_ui():
|
||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
||||||
|
|
||||||
with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
|
def select_img2img_tab(tab):
|
||||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||||
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
|
|
||||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
|
||||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
elem.select(
|
||||||
|
fn=lambda tab=i: select_img2img_tab(tab),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[inpaint_controls, mask_alpha],
|
||||||
|
)
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||||
|
@ -900,20 +914,6 @@ def create_ui():
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_mode.change(
|
|
||||||
lambda mode, img: {
|
|
||||||
init_img_with_mask: gr_show(mode == 0),
|
|
||||||
init_img_inpaint: gr_show(mode == 1),
|
|
||||||
init_mask_inpaint: gr_show(mode == 1),
|
|
||||||
},
|
|
||||||
inputs=[mask_mode, init_img_with_mask],
|
|
||||||
outputs=[
|
|
||||||
init_img_with_mask,
|
|
||||||
init_img_inpaint,
|
|
||||||
init_mask_inpaint,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||||
_js="submit_img2img",
|
_js="submit_img2img",
|
||||||
|
@ -924,11 +924,12 @@ def create_ui():
|
||||||
img2img_prompt_style,
|
img2img_prompt_style,
|
||||||
img2img_prompt_style2,
|
img2img_prompt_style2,
|
||||||
init_img,
|
init_img,
|
||||||
|
sketch,
|
||||||
init_img_with_mask,
|
init_img_with_mask,
|
||||||
init_img_with_mask_orig,
|
inpaint_color_sketch,
|
||||||
|
inpaint_color_sketch_orig,
|
||||||
init_img_inpaint,
|
init_img_inpaint,
|
||||||
init_mask_inpaint,
|
init_mask_inpaint,
|
||||||
mask_mode,
|
|
||||||
steps,
|
steps,
|
||||||
sampler_index,
|
sampler_index,
|
||||||
mask_blur,
|
mask_blur,
|
||||||
|
@ -1129,7 +1130,7 @@ def create_ui():
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||||
|
|
||||||
with gr.Row():
|
with FormRow():
|
||||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
||||||
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
||||||
|
|
||||||
|
@ -1143,11 +1144,13 @@ def create_ui():
|
||||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
||||||
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
||||||
|
|
||||||
with gr.Row():
|
with FormRow():
|
||||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||||
|
|
||||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
||||||
|
|
||||||
|
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||||
|
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||||
|
@ -1703,6 +1706,7 @@ def create_ui():
|
||||||
save_as_half,
|
save_as_half,
|
||||||
custom_name,
|
custom_name,
|
||||||
checkpoint_format,
|
checkpoint_format,
|
||||||
|
config_source,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
submit_result,
|
submit_result,
|
||||||
|
@ -1837,4 +1841,6 @@ xformers: {xformers_version}
|
||||||
gradio: {gr.__version__}
|
gradio: {gr.__version__}
|
||||||
•
|
•
|
||||||
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
|
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
|
||||||
|
•
|
||||||
|
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -52,7 +52,7 @@ def check_progress_call(id_part):
|
||||||
image = gr.update(visible=False)
|
image = gr.update(visible=False)
|
||||||
preview_visibility = gr.update(visible=False)
|
preview_visibility = gr.update(visible=False)
|
||||||
|
|
||||||
if opts.show_progress_every_n_steps != 0:
|
if opts.live_previews_enable:
|
||||||
shared.state.set_current_image()
|
shared.state.set_current_image()
|
||||||
image = shared.state.current_image
|
image = shared.state.current_image
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
function gradioApp() {
|
function gradioApp() {
|
||||||
const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot
|
const elems = document.getElementsByTagName('gradio-app')
|
||||||
|
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
|
||||||
return !!gradioShadowRoot ? gradioShadowRoot : document;
|
return !!gradioShadowRoot ? gradioShadowRoot : document;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,6 +14,7 @@ function get_uiCurrentTabContent() {
|
||||||
|
|
||||||
uiUpdateCallbacks = []
|
uiUpdateCallbacks = []
|
||||||
uiTabChangeCallbacks = []
|
uiTabChangeCallbacks = []
|
||||||
|
optionsChangedCallbacks = []
|
||||||
let uiCurrentTab = null
|
let uiCurrentTab = null
|
||||||
|
|
||||||
function onUiUpdate(callback){
|
function onUiUpdate(callback){
|
||||||
|
@ -21,6 +23,9 @@ function onUiUpdate(callback){
|
||||||
function onUiTabChange(callback){
|
function onUiTabChange(callback){
|
||||||
uiTabChangeCallbacks.push(callback)
|
uiTabChangeCallbacks.push(callback)
|
||||||
}
|
}
|
||||||
|
function onOptionsChanged(callback){
|
||||||
|
optionsChangedCallbacks.push(callback)
|
||||||
|
}
|
||||||
|
|
||||||
function runCallback(x, m){
|
function runCallback(x, m){
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -146,11 +146,7 @@ class Script(scripts.Script):
|
||||||
else:
|
else:
|
||||||
args = {"prompt": line}
|
args = {"prompt": line}
|
||||||
|
|
||||||
n_iter = args.get("n_iter", 1)
|
job_count += args.get("n_iter", p.n_iter)
|
||||||
if n_iter != 1:
|
|
||||||
job_count += n_iter
|
|
||||||
else:
|
|
||||||
job_count += 1
|
|
||||||
|
|
||||||
jobs.append(args)
|
jobs.append(args)
|
||||||
|
|
||||||
|
|
|
@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs):
|
||||||
|
|
||||||
|
|
||||||
def find_vae(name: str):
|
def find_vae(name: str):
|
||||||
if name.lower() in ['auto', 'none']:
|
if name.lower() in ['auto', 'automatic']:
|
||||||
return name
|
return modules.sd_vae.unspecified
|
||||||
|
if name.lower() == 'none':
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
|
choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
|
||||||
found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
|
if len(choices) == 0:
|
||||||
if found:
|
print(f"No VAE found for {name}; using automatic")
|
||||||
return found[0]
|
return modules.sd_vae.unspecified
|
||||||
else:
|
else:
|
||||||
return 'auto'
|
return modules.sd_vae.vae_dict[choices[0]]
|
||||||
|
|
||||||
|
|
||||||
def apply_vae(p, x, xs):
|
def apply_vae(p, x, xs):
|
||||||
if x.lower().strip() == 'none':
|
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
|
||||||
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
|
|
||||||
else:
|
|
||||||
found = find_vae(x)
|
|
||||||
if found:
|
|
||||||
v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
|
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
|
||||||
|
@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object):
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, tb):
|
def __exit__(self, exc_type, exc_value, tb):
|
||||||
modules.sd_models.reload_model_weights(self.model)
|
modules.sd_models.reload_model_weights(self.model)
|
||||||
modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
|
|
||||||
|
opts.data["sd_vae"] = self.vae
|
||||||
|
modules.sd_vae.reload_vae_weights(self.model)
|
||||||
|
|
||||||
hypernetwork.load_hypernetwork(self.hypernetwork)
|
hypernetwork.load_hypernetwork(self.hypernetwork)
|
||||||
hypernetwork.apply_strength()
|
hypernetwork.apply_strength()
|
||||||
|
|
|
@ -557,7 +557,9 @@ canvas[key="mask"] {
|
||||||
}
|
}
|
||||||
|
|
||||||
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
|
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
|
||||||
img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img
|
#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img,
|
||||||
|
#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img,
|
||||||
|
#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img
|
||||||
{
|
{
|
||||||
height: 480px !important;
|
height: 480px !important;
|
||||||
max-height: 480px !important;
|
max-height: 480px !important;
|
||||||
|
|
2
webui.py
2
webui.py
|
@ -78,6 +78,8 @@ def initialize():
|
||||||
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
|
||||||
|
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user