Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
7cb31a278e | ||
|
2abd89acc6 |
|
@ -20,7 +20,8 @@ model:
|
||||||
conditioning_key: hybrid
|
conditioning_key: hybrid
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
use_ema: false
|
use_ema: true
|
||||||
|
load_ema: true
|
||||||
|
|
||||||
scheduler_config: # 10000 warmup steps
|
scheduler_config: # 10000 warmup steps
|
||||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
|
|
@ -12,7 +12,7 @@ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)
|
||||||
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
||||||
|
|
||||||
|
|
||||||
def convert_diffusers_name_to_compvis(key):
|
def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||||
def match(match_list, regex):
|
def match(match_list, regex):
|
||||||
r = re.match(regex, key)
|
r = re.match(regex, key)
|
||||||
if not r:
|
if not r:
|
||||||
|
@ -34,6 +34,14 @@ def convert_diffusers_name_to_compvis(key):
|
||||||
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
||||||
|
|
||||||
if match(m, re_text_block):
|
if match(m, re_text_block):
|
||||||
|
if is_sd2:
|
||||||
|
if 'mlp_fc1' in m[1]:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||||
|
elif 'mlp_fc2' in m[1]:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||||
|
else:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||||
|
|
||||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||||
|
|
||||||
return key
|
return key
|
||||||
|
@ -83,9 +91,10 @@ def load_lora(name, filename):
|
||||||
sd = sd_models.read_state_dict(filename)
|
sd = sd_models.read_state_dict(filename)
|
||||||
|
|
||||||
keys_failed_to_match = []
|
keys_failed_to_match = []
|
||||||
|
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||||
|
|
||||||
for key_diffusers, weight in sd.items():
|
for key_diffusers, weight in sd.items():
|
||||||
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
|
fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
|
||||||
key, lora_key = fullkey.split(".", 1)
|
key, lora_key = fullkey.split(".", 1)
|
||||||
|
|
||||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||||
|
@ -104,9 +113,13 @@ def load_lora(name, filename):
|
||||||
|
|
||||||
if type(sd_module) == torch.nn.Linear:
|
if type(sd_module) == torch.nn.Linear:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
|
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
|
||||||
|
module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d:
|
||||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||||
else:
|
else:
|
||||||
|
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||||
|
continue
|
||||||
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -182,6 +195,10 @@ def lora_Conv2d_forward(self, input):
|
||||||
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
||||||
|
|
||||||
|
|
||||||
|
def lora_NonDynamicallyQuantizableLinear_forward(self, input):
|
||||||
|
return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
|
||||||
|
|
||||||
|
|
||||||
def list_available_loras():
|
def list_available_loras():
|
||||||
available_loras.clear()
|
available_loras.clear()
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||||
|
torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora
|
||||||
|
|
||||||
|
|
||||||
def before_ui():
|
def before_ui():
|
||||||
|
@ -23,8 +24,12 @@ if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
||||||
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
||||||
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'NonDynamicallyQuantizableLinear_forward_before_lora'):
|
||||||
|
torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora = torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward
|
||||||
|
|
||||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||||
|
torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = lora.lora_NonDynamicallyQuantizableLinear_forward
|
||||||
|
|
||||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
|
|
|
@ -20,14 +20,13 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||||
preview = None
|
preview = None
|
||||||
for file in previews:
|
for file in previews:
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
preview = self.link_preview(file)
|
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": path,
|
"filename": path,
|
||||||
"preview": preview,
|
"preview": preview,
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
|
||||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||||
"local_preview": path + ".png",
|
"local_preview": path + ".png",
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
<ul>
|
<ul>
|
||||||
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
||||||
</ul>
|
</ul>
|
||||||
<span style="display:none" class='search_term'>{search_term}</span>
|
|
||||||
</div>
|
</div>
|
||||||
<span class='name'>{name}</span>
|
<span class='name'>{name}</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -16,7 +16,7 @@ function setupExtraNetworksForTab(tabname){
|
||||||
searchTerm = search.value.toLowerCase()
|
searchTerm = search.value.toLowerCase()
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
||||||
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
text = elem.querySelector('.name').textContent.toLowerCase()
|
||||||
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
@ -48,39 +48,10 @@ function setupExtraNetworks(){
|
||||||
|
|
||||||
onUiLoaded(setupExtraNetworks)
|
onUiLoaded(setupExtraNetworks)
|
||||||
|
|
||||||
var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
|
|
||||||
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
|
|
||||||
|
|
||||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
|
||||||
var m = text.match(re_extranet)
|
|
||||||
if(! m) return false
|
|
||||||
|
|
||||||
var partToSearch = m[1]
|
|
||||||
var replaced = false
|
|
||||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
|
|
||||||
m = found.match(re_extranet);
|
|
||||||
if(m[1] == partToSearch){
|
|
||||||
replaced = true;
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return found;
|
|
||||||
})
|
|
||||||
|
|
||||||
if(replaced){
|
|
||||||
textarea.value = newTextareaText
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
function cardClicked(tabname, textToAdd, allowNegativePrompt){
|
function cardClicked(tabname, textToAdd, allowNegativePrompt){
|
||||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
||||||
|
|
||||||
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
|
textarea.value = textarea.value + " " + textToAdd
|
||||||
textarea.value = textarea.value + " " + textToAdd
|
|
||||||
}
|
|
||||||
|
|
||||||
updateInput(textarea)
|
updateInput(textarea)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,12 +67,3 @@ function saveCardPreview(event, tabname, filename){
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
event.preventDefault()
|
event.preventDefault()
|
||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksSearchButton(tabs_id, event){
|
|
||||||
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
|
||||||
button = event.target
|
|
||||||
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
|
||||||
|
|
||||||
searchTextarea.value = text
|
|
||||||
updateInput(searchTextarea)
|
|
||||||
}
|
|
|
@ -17,7 +17,7 @@ titles = {
|
||||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||||
"\u{1f4c2}": "Open images output directory",
|
"\u{1f4c2}": "Open images output directory",
|
||||||
"\u{1f4be}": "Save style",
|
"\u{1f4be}": "Save style",
|
||||||
"\u{1f5d1}": "Clear prompt",
|
"\U0001F5D1": "Clear prompt",
|
||||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||||
"\u{1f4d2}": "Paste available values into the field",
|
"\u{1f4d2}": "Paste available values into the field",
|
||||||
"\u{1f3b4}": "Show extra networks",
|
"\u{1f3b4}": "Show extra networks",
|
||||||
|
@ -66,8 +66,8 @@ titles = {
|
||||||
|
|
||||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||||
|
|
||||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||||
|
|
||||||
"Loopback": "Process an image, use it as an input, repeat.",
|
"Loopback": "Process an image, use it as an input, repeat.",
|
||||||
|
|
|
@ -191,28 +191,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
||||||
return [prompt, negative_prompt]
|
return [prompt, negative_prompt]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
promptTokecountUpdateFuncs = {}
|
|
||||||
|
|
||||||
function recalculatePromptTokens(name){
|
|
||||||
if(promptTokecountUpdateFuncs[name]){
|
|
||||||
promptTokecountUpdateFuncs[name]()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function recalculate_prompts_txt2img(){
|
|
||||||
recalculatePromptTokens('txt2img_prompt')
|
|
||||||
recalculatePromptTokens('txt2img_neg_prompt')
|
|
||||||
return args_to_array(arguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
function recalculate_prompts_img2img(){
|
|
||||||
recalculatePromptTokens('img2img_prompt')
|
|
||||||
recalculatePromptTokens('img2img_neg_prompt')
|
|
||||||
return args_to_array(arguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
opts = {}
|
opts = {}
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(Object.keys(opts).length != 0) return;
|
if(Object.keys(opts).length != 0) return;
|
||||||
|
@ -254,12 +232,14 @@ onUiUpdate(function(){
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
prompt.parentElement.insertBefore(counter, prompt)
|
prompt.parentElement.insertBefore(counter, prompt)
|
||||||
counter.classList.add("token-counter")
|
counter.classList.add("token-counter")
|
||||||
prompt.parentElement.style.position = "relative"
|
prompt.parentElement.style.position = "relative"
|
||||||
|
|
||||||
promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
|
textarea.addEventListener("input", function(){
|
||||||
textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
|
update_token_counter(id_button);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
||||||
|
@ -293,7 +273,7 @@ onOptionsChanged(function(){
|
||||||
|
|
||||||
let txt2img_textarea, img2img_textarea = undefined;
|
let txt2img_textarea, img2img_textarea = undefined;
|
||||||
let wait_time = 800
|
let wait_time = 800
|
||||||
let token_timeouts = {};
|
let token_timeout;
|
||||||
|
|
||||||
function update_txt2img_tokens(...args) {
|
function update_txt2img_tokens(...args) {
|
||||||
update_token_counter("txt2img_token_button")
|
update_token_counter("txt2img_token_button")
|
||||||
|
@ -310,9 +290,9 @@ function update_img2img_tokens(...args) {
|
||||||
}
|
}
|
||||||
|
|
||||||
function update_token_counter(button_id) {
|
function update_token_counter(button_id) {
|
||||||
if (token_timeouts[button_id])
|
if (token_timeout)
|
||||||
clearTimeout(token_timeouts[button_id]);
|
clearTimeout(token_timeout);
|
||||||
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
function restart_reload(){
|
function restart_reload(){
|
||||||
|
@ -329,10 +309,3 @@ function updateInput(target){
|
||||||
Object.defineProperty(e, "target", {value: target})
|
Object.defineProperty(e, "target", {value: target})
|
||||||
target.dispatchEvent(e);
|
target.dispatchEvent(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
var desiredCheckpointName = null;
|
|
||||||
function selectCheckpoint(name){
|
|
||||||
desiredCheckpointName = name;
|
|
||||||
gradioApp().getElementById('change_checkpoint').click()
|
|
||||||
}
|
|
||||||
|
|
|
@ -223,7 +223,6 @@ def prepare_environment():
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||||
|
@ -283,14 +282,14 @@ def prepare_environment():
|
||||||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
if platform.python_version().startswith("3.10"):
|
if platform.python_version().startswith("3.10"):
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
|
||||||
else:
|
else:
|
||||||
print("Installation of xformers is not supported in this version of Python.")
|
print("Installation of xformers is not supported in this version of Python.")
|
||||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||||
if not is_installed("xformers"):
|
if not is_installed("xformers"):
|
||||||
exit(0)
|
exit(0)
|
||||||
elif platform.system() == "Linux":
|
elif platform.system() == "Linux":
|
||||||
run_pip(f"install {xformers_package}", "xformers")
|
run_pip("install xformers==0.0.16rc425", "xformers")
|
||||||
|
|
||||||
if not is_installed("pyngrok") and ngrok:
|
if not is_installed("pyngrok") and ngrok:
|
||||||
run_pip("install pyngrok", "ngrok")
|
run_pip("install pyngrok", "ngrok")
|
||||||
|
|
|
@ -16,10 +16,6 @@ def has_mps() -> bool:
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def has_dml():
|
|
||||||
import importlib
|
|
||||||
loader = importlib.find_loader('torch_directml')
|
|
||||||
return loader is not None
|
|
||||||
|
|
||||||
def extract_device_id(args, name):
|
def extract_device_id(args, name):
|
||||||
for x in range(len(args)):
|
for x in range(len(args)):
|
||||||
|
@ -39,23 +35,16 @@ def get_cuda_device_string():
|
||||||
|
|
||||||
|
|
||||||
def get_optimal_device_name():
|
def get_optimal_device_name():
|
||||||
if has_dml():
|
if torch.cuda.is_available():
|
||||||
return "dml"
|
return get_cuda_device_string()
|
||||||
|
|
||||||
if has_mps():
|
if has_mps():
|
||||||
return "mps"
|
return "mps"
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return get_cuda_device_string()
|
|
||||||
|
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
def get_optimal_device():
|
def get_optimal_device():
|
||||||
if get_optimal_device_name() == "dml":
|
|
||||||
import torch_directml
|
|
||||||
return torch_directml.device()
|
|
||||||
|
|
||||||
return torch.device(get_optimal_device_name())
|
return torch.device(get_optimal_device_name())
|
||||||
|
|
||||||
|
|
||||||
|
@ -218,22 +207,3 @@ if has_mps():
|
||||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||||
|
|
||||||
if has_dml():
|
|
||||||
_cumsum = torch.cumsum
|
|
||||||
_repeat_interleave = torch.repeat_interleave
|
|
||||||
_multinomial = torch.multinomial
|
|
||||||
|
|
||||||
_Tensor_new = torch.Tensor.new
|
|
||||||
_Tensor_cumsum = torch.Tensor.cumsum
|
|
||||||
_Tensor_repeat_interleave = torch.Tensor.repeat_interleave
|
|
||||||
_Tensor_multinomial = torch.Tensor.multinomial
|
|
||||||
|
|
||||||
torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
|
|
||||||
torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
|
|
||||||
torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
|
|
||||||
|
|
||||||
torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
|
|
||||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
|
|
||||||
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
|
|
||||||
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
|
|
|
@ -1,5 +1,4 @@
|
||||||
import base64
|
import base64
|
||||||
import html
|
|
||||||
import io
|
import io
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
@ -17,23 +16,13 @@ re_param = re.compile(re_param_code)
|
||||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||||
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
||||||
type_of_gr_update = type(gr.update())
|
type_of_gr_update = type(gr.update())
|
||||||
|
|
||||||
paste_fields = {}
|
paste_fields = {}
|
||||||
registered_param_bindings = []
|
bind_list = []
|
||||||
|
|
||||||
|
|
||||||
class ParamBinding:
|
|
||||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
|
|
||||||
self.paste_button = paste_button
|
|
||||||
self.tabname = tabname
|
|
||||||
self.source_text_component = source_text_component
|
|
||||||
self.source_image_component = source_image_component
|
|
||||||
self.source_tabname = source_tabname
|
|
||||||
self.override_settings_component = override_settings_component
|
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
paste_fields.clear()
|
paste_fields.clear()
|
||||||
|
bind_list.clear()
|
||||||
|
|
||||||
|
|
||||||
def quote(text):
|
def quote(text):
|
||||||
|
@ -85,6 +74,26 @@ def add_paste_fields(tabname, init_img, fields):
|
||||||
modules.ui.img2img_paste_fields = fields
|
modules.ui.img2img_paste_fields = fields
|
||||||
|
|
||||||
|
|
||||||
|
def integrate_settings_paste_fields(component_dict):
|
||||||
|
from modules import ui
|
||||||
|
|
||||||
|
settings_map = {
|
||||||
|
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||||
|
'inpainting_mask_weight': 'Conditional mask weight',
|
||||||
|
'sd_model_checkpoint': 'Model hash',
|
||||||
|
'eta_noise_seed_delta': 'ENSD',
|
||||||
|
'initial_noise_multiplier': 'Noise multiplier',
|
||||||
|
}
|
||||||
|
settings_paste_fields = [
|
||||||
|
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||||
|
for k, v in settings_map.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
for tabname, info in paste_fields.items():
|
||||||
|
if info["fields"] is not None:
|
||||||
|
info["fields"] += settings_paste_fields
|
||||||
|
|
||||||
|
|
||||||
def create_buttons(tabs_list):
|
def create_buttons(tabs_list):
|
||||||
buttons = {}
|
buttons = {}
|
||||||
for tab in tabs_list:
|
for tab in tabs_list:
|
||||||
|
@ -92,60 +101,9 @@ def create_buttons(tabs_list):
|
||||||
return buttons
|
return buttons
|
||||||
|
|
||||||
|
|
||||||
|
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
|
||||||
def bind_buttons(buttons, send_image, send_generate_info):
|
def bind_buttons(buttons, send_image, send_generate_info):
|
||||||
"""old function for backwards compatibility; do not use this, use register_paste_params_button"""
|
bind_list.append([buttons, send_image, send_generate_info])
|
||||||
for tabname, button in buttons.items():
|
|
||||||
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
|
||||||
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
|
||||||
|
|
||||||
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
|
|
||||||
|
|
||||||
|
|
||||||
def register_paste_params_button(binding: ParamBinding):
|
|
||||||
registered_param_bindings.append(binding)
|
|
||||||
|
|
||||||
|
|
||||||
def connect_paste_params_buttons():
|
|
||||||
binding: ParamBinding
|
|
||||||
for binding in registered_param_bindings:
|
|
||||||
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
|
||||||
fields = paste_fields[binding.tabname]["fields"]
|
|
||||||
|
|
||||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
|
||||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
|
||||||
|
|
||||||
if binding.source_image_component and destination_image_component:
|
|
||||||
if isinstance(binding.source_image_component, gr.Gallery):
|
|
||||||
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
|
||||||
jsfunc = "extract_image_from_gallery"
|
|
||||||
else:
|
|
||||||
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
|
||||||
jsfunc = None
|
|
||||||
|
|
||||||
binding.paste_button.click(
|
|
||||||
fn=func,
|
|
||||||
_js=jsfunc,
|
|
||||||
inputs=[binding.source_image_component],
|
|
||||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
|
||||||
)
|
|
||||||
|
|
||||||
if binding.source_text_component is not None and fields is not None:
|
|
||||||
connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
|
|
||||||
|
|
||||||
if binding.source_tabname is not None and fields is not None:
|
|
||||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
|
||||||
binding.paste_button.click(
|
|
||||||
fn=lambda *x: x,
|
|
||||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
|
||||||
outputs=[field for field, name in fields if name in paste_field_names],
|
|
||||||
)
|
|
||||||
|
|
||||||
binding.paste_button.click(
|
|
||||||
fn=None,
|
|
||||||
_js=f"switch_to_{binding.tabname}",
|
|
||||||
inputs=None,
|
|
||||||
outputs=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def send_image_and_dimensions(x):
|
def send_image_and_dimensions(x):
|
||||||
|
@ -164,6 +122,49 @@ def send_image_and_dimensions(x):
|
||||||
return img, w, h
|
return img, w, h
|
||||||
|
|
||||||
|
|
||||||
|
def run_bind():
|
||||||
|
for buttons, source_image_component, send_generate_info in bind_list:
|
||||||
|
for tab in buttons:
|
||||||
|
button = buttons[tab]
|
||||||
|
destination_image_component = paste_fields[tab]["init_img"]
|
||||||
|
fields = paste_fields[tab]["fields"]
|
||||||
|
|
||||||
|
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||||
|
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||||
|
|
||||||
|
if source_image_component and destination_image_component:
|
||||||
|
if isinstance(source_image_component, gr.Gallery):
|
||||||
|
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
||||||
|
jsfunc = "extract_image_from_gallery"
|
||||||
|
else:
|
||||||
|
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
||||||
|
jsfunc = None
|
||||||
|
|
||||||
|
button.click(
|
||||||
|
fn=func,
|
||||||
|
_js=jsfunc,
|
||||||
|
inputs=[source_image_component],
|
||||||
|
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||||
|
)
|
||||||
|
|
||||||
|
if send_generate_info and fields is not None:
|
||||||
|
if send_generate_info in paste_fields:
|
||||||
|
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||||
|
button.click(
|
||||||
|
fn=lambda *x: x,
|
||||||
|
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
|
||||||
|
outputs=[field for field, name in fields if name in paste_field_names],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
connect_paste(button, fields, send_generate_info)
|
||||||
|
|
||||||
|
button.click(
|
||||||
|
fn=None,
|
||||||
|
_js=f"switch_to_{tab}",
|
||||||
|
inputs=None,
|
||||||
|
outputs=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
||||||
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
||||||
|
@ -285,50 +286,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
settings_map = {}
|
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||||
|
|
||||||
infotext_to_setting_name_mapping = [
|
|
||||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
|
||||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
|
||||||
('Model hash', 'sd_model_checkpoint'),
|
|
||||||
('ENSD', 'eta_noise_seed_delta'),
|
|
||||||
('Noise multiplier', 'initial_noise_multiplier'),
|
|
||||||
('Eta', 'eta_ancestral'),
|
|
||||||
('Eta DDIM', 'eta_ddim'),
|
|
||||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def create_override_settings_dict(text_pairs):
|
|
||||||
"""creates processing's override_settings parameters from gradio's multiselect
|
|
||||||
|
|
||||||
Example input:
|
|
||||||
['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
|
|
||||||
|
|
||||||
Example output:
|
|
||||||
{'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
|
|
||||||
"""
|
|
||||||
|
|
||||||
res = {}
|
|
||||||
|
|
||||||
params = {}
|
|
||||||
for pair in text_pairs:
|
|
||||||
k, v = pair.split(":", maxsplit=1)
|
|
||||||
|
|
||||||
params[k] = v.strip()
|
|
||||||
|
|
||||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
|
||||||
value = params.get(param_name, None)
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
res[setting_name] = shared.opts.cast_value(setting_name, value)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
|
||||||
def paste_func(prompt):
|
def paste_func(prompt):
|
||||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||||
filename = os.path.join(data_path, "params.txt")
|
filename = os.path.join(data_path, "params.txt")
|
||||||
|
@ -365,35 +323,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
if override_settings_component is not None:
|
|
||||||
def paste_settings(params):
|
|
||||||
vals = {}
|
|
||||||
|
|
||||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
|
||||||
v = params.get(param_name, None)
|
|
||||||
if v is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
|
||||||
continue
|
|
||||||
|
|
||||||
v = shared.opts.cast_value(setting_name, v)
|
|
||||||
current_value = getattr(shared.opts, setting_name, None)
|
|
||||||
|
|
||||||
if v == current_value:
|
|
||||||
continue
|
|
||||||
|
|
||||||
vals[param_name] = v
|
|
||||||
|
|
||||||
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
|
|
||||||
|
|
||||||
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
|
|
||||||
|
|
||||||
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
|
|
||||||
|
|
||||||
button.click(
|
button.click(
|
||||||
fn=paste_func,
|
fn=paste_func,
|
||||||
_js=f"recalculate_prompts_{tabname}",
|
_js=jsfunc,
|
||||||
inputs=[input_comp],
|
inputs=[input_comp],
|
||||||
outputs=[x[0] for x in paste_fields],
|
outputs=[x[0] for x in paste_fields],
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,7 +4,6 @@ import os.path
|
||||||
|
|
||||||
import filelock
|
import filelock
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,9 +68,6 @@ def sha256(filename, title):
|
||||||
if sha256_value is not None:
|
if sha256_value is not None:
|
||||||
return sha256_value
|
return sha256_value
|
||||||
|
|
||||||
if shared.cmd_opts.no_hashing:
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"Calculating sha256 for {filename}: ", end='')
|
print(f"Calculating sha256 for {filename}: ", end='')
|
||||||
sha256_value = calculate_sha256(filename)
|
sha256_value = calculate_sha256(filename)
|
||||||
print(f"{sha256_value}")
|
print(f"{sha256_value}")
|
||||||
|
|
|
@ -307,7 +307,7 @@ class Hypernetwork:
|
||||||
def shorthash(self):
|
def shorthash(self):
|
||||||
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||||
|
|
||||||
return sha256[0:10] if sha256 else None
|
return sha256[0:10]
|
||||||
|
|
||||||
|
|
||||||
def list_hypernetworks(path):
|
def list_hypernetworks(path):
|
||||||
|
|
|
@ -16,7 +16,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||||
from fonts.ttf import Roboto
|
from fonts.ttf import Roboto
|
||||||
import string
|
import string
|
||||||
import json
|
import json
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from modules import sd_samplers, shared, script_callbacks
|
from modules import sd_samplers, shared, script_callbacks
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
|
@ -37,8 +36,6 @@ def image_grid(imgs, batch_size=1, rows=None):
|
||||||
else:
|
else:
|
||||||
rows = math.sqrt(len(imgs))
|
rows = math.sqrt(len(imgs))
|
||||||
rows = round(rows)
|
rows = round(rows)
|
||||||
if rows > len(imgs):
|
|
||||||
rows = len(imgs)
|
|
||||||
|
|
||||||
cols = math.ceil(len(imgs) / rows)
|
cols = math.ceil(len(imgs) / rows)
|
||||||
|
|
||||||
|
@ -131,7 +128,7 @@ class GridAnnotation:
|
||||||
self.size = None
|
self.size = None
|
||||||
|
|
||||||
|
|
||||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||||
def wrap(drawing, text, font, line_length):
|
def wrap(drawing, text, font, line_length):
|
||||||
lines = ['']
|
lines = ['']
|
||||||
for word in text.split():
|
for word in text.split():
|
||||||
|
@ -195,35 +192,32 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||||
line.allowed_width = allowed_width
|
line.allowed_width = allowed_width
|
||||||
|
|
||||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||||
|
ver_texts]
|
||||||
|
|
||||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||||
|
|
||||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
|
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
|
||||||
|
result.paste(im, (pad_left, pad_top))
|
||||||
for row in range(rows):
|
|
||||||
for col in range(cols):
|
|
||||||
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
|
||||||
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
|
||||||
|
|
||||||
d = ImageDraw.Draw(result)
|
d = ImageDraw.Draw(result)
|
||||||
|
|
||||||
for col in range(cols):
|
for col in range(cols):
|
||||||
x = pad_left + (width + margin) * col + width / 2
|
x = pad_left + width * col + width / 2
|
||||||
y = pad_top / 2 - hor_text_heights[col] / 2
|
y = pad_top / 2 - hor_text_heights[col] / 2
|
||||||
|
|
||||||
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
||||||
|
|
||||||
for row in range(rows):
|
for row in range(rows):
|
||||||
x = pad_left / 2
|
x = pad_left / 2
|
||||||
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
|
||||||
|
|
||||||
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||||
prompts = all_prompts[1:]
|
prompts = all_prompts[1:]
|
||||||
boundary = math.ceil(len(prompts) / 2)
|
boundary = math.ceil(len(prompts) / 2)
|
||||||
|
|
||||||
|
@ -233,7 +227,7 @@ def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
||||||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
||||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
||||||
|
|
||||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||||
|
|
||||||
|
|
||||||
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||||
|
@ -344,7 +338,6 @@ class FilenameGenerator:
|
||||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||||
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
|
||||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||||
|
|
|
@ -7,7 +7,6 @@ import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||||
|
|
||||||
from modules import devices, sd_samplers
|
from modules import devices, sd_samplers
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -76,9 +75,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
||||||
processed_image.save(os.path.join(output_dir, filename))
|
processed_image.save(os.path.join(output_dir, filename))
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
|
|
||||||
if mode == 0: # img2img
|
if mode == 0: # img2img
|
||||||
|
@ -142,11 +139,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||||
inpainting_fill=inpainting_fill,
|
inpainting_fill=inpainting_fill,
|
||||||
resize_mode=resize_mode,
|
resize_mode=resize_mode,
|
||||||
denoising_strength=denoising_strength,
|
denoising_strength=denoising_strength,
|
||||||
image_cfg_scale=image_cfg_scale,
|
|
||||||
inpaint_full_res=inpaint_full_res,
|
inpaint_full_res=inpaint_full_res,
|
||||||
inpaint_full_res_padding=inpaint_full_res_padding,
|
inpaint_full_res_padding=inpaint_full_res_padding,
|
||||||
inpainting_mask_invert=inpainting_mask_invert,
|
inpainting_mask_invert=inpainting_mask_invert,
|
||||||
override_settings=override_settings,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
p.scripts = modules.scripts.scripts_txt2img
|
p.scripts = modules.scripts.scripts_txt2img
|
||||||
|
|
|
@ -1,53 +0,0 @@
|
||||||
import torch
|
|
||||||
from modules import paths
|
|
||||||
from modules.sd_hijack_utils import CondFunc
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
|
||||||
# check `getattr` and try it for compatibility
|
|
||||||
def check_for_mps() -> bool:
|
|
||||||
if not getattr(torch, 'has_mps', False):
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
torch.zeros(1).to(torch.device("mps"))
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
has_mps = check_for_mps()
|
|
||||||
|
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
|
||||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|
||||||
if input.device.type == 'mps':
|
|
||||||
output_dtype = kwargs.get('dtype', input.dtype)
|
|
||||||
if output_dtype == torch.int64:
|
|
||||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
|
||||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
|
||||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
|
||||||
return cumsum_func(input, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
if has_mps:
|
|
||||||
# MPS fix for randn in torchsde
|
|
||||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
|
||||||
|
|
||||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
|
||||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
|
||||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
|
||||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
|
||||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
|
||||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
|
||||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
|
||||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
|
||||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
|
||||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
|
||||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
|
||||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
|
||||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
|
||||||
|
|
|
@ -45,9 +45,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||||
full_path = file
|
full_path = file
|
||||||
if os.path.isdir(full_path):
|
if os.path.isdir(full_path):
|
||||||
continue
|
continue
|
||||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
|
||||||
print(f"Skipping broken symlink: {full_path}")
|
|
||||||
continue
|
|
||||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
||||||
continue
|
continue
|
||||||
if len(ext_filter) != 0:
|
if len(ext_filter) != 0:
|
||||||
|
|
|
@ -186,7 +186,7 @@ class StableDiffusionProcessing:
|
||||||
return conditioning
|
return conditioning
|
||||||
|
|
||||||
def edit_image_conditioning(self, source_image):
|
def edit_image_conditioning(self, source_image):
|
||||||
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||||
|
|
||||||
return conditioning_image
|
return conditioning_image
|
||||||
|
|
||||||
|
@ -268,7 +268,6 @@ class Processed:
|
||||||
self.height = p.height
|
self.height = p.height
|
||||||
self.sampler_name = p.sampler_name
|
self.sampler_name = p.sampler_name
|
||||||
self.cfg_scale = p.cfg_scale
|
self.cfg_scale = p.cfg_scale
|
||||||
self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
|
||||||
self.steps = p.steps
|
self.steps = p.steps
|
||||||
self.batch_size = p.batch_size
|
self.batch_size = p.batch_size
|
||||||
self.restore_faces = p.restore_faces
|
self.restore_faces = p.restore_faces
|
||||||
|
@ -446,17 +445,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": p.sampler_name,
|
"Sampler": p.sampler_name,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
|
||||||
"Seed": all_seeds[index],
|
"Seed": all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||||
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||||
}
|
}
|
||||||
|
@ -903,13 +904,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.init_images = init_images
|
self.init_images = init_images
|
||||||
self.resize_mode: int = resize_mode
|
self.resize_mode: int = resize_mode
|
||||||
self.denoising_strength: float = denoising_strength
|
self.denoising_strength: float = denoising_strength
|
||||||
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.image_mask = mask
|
self.image_mask = mask
|
||||||
self.latent_mask = None
|
self.latent_mask = None
|
||||||
|
|
|
@ -20,9 +20,8 @@ class DisableInitialization:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, disable_clip=True):
|
def __init__(self):
|
||||||
self.replaced = []
|
self.replaced = []
|
||||||
self.disable_clip = disable_clip
|
|
||||||
|
|
||||||
def replace(self, obj, field, func):
|
def replace(self, obj, field, func):
|
||||||
original = getattr(obj, field, None)
|
original = getattr(obj, field, None)
|
||||||
|
@ -76,14 +75,12 @@ class DisableInitialization:
|
||||||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||||
|
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||||
if self.disable_clip:
|
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
|
||||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
for obj, field, original in self.replaced:
|
for obj, field, original in self.replaced:
|
||||||
|
|
|
@ -41,7 +41,6 @@ class CheckpointInfo:
|
||||||
name = name[1:]
|
name = name[1:]
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
|
||||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
self.hash = model_hash(filename)
|
self.hash = model_hash(filename)
|
||||||
|
|
||||||
|
@ -59,17 +58,13 @@ class CheckpointInfo:
|
||||||
|
|
||||||
def calculate_shorthash(self):
|
def calculate_shorthash(self):
|
||||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||||
if self.sha256 is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.shorthash = self.sha256[0:10]
|
self.shorthash = self.sha256[0:10]
|
||||||
|
|
||||||
if self.shorthash not in self.ids:
|
if self.shorthash not in self.ids:
|
||||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
self.ids += [self.shorthash, self.sha256]
|
||||||
|
self.register()
|
||||||
|
|
||||||
checkpoints_list.pop(self.title)
|
|
||||||
self.title = f'{self.name} [{self.shorthash}]'
|
self.title = f'{self.name} [{self.shorthash}]'
|
||||||
self.register()
|
|
||||||
|
|
||||||
return self.shorthash
|
return self.shorthash
|
||||||
|
|
||||||
|
@ -162,7 +157,7 @@ def select_checkpoint():
|
||||||
print(f" - directory {model_path}", file=sys.stderr)
|
print(f" - directory {model_path}", file=sys.stderr)
|
||||||
if shared.cmd_opts.ckpt_dir is not None:
|
if shared.cmd_opts.ckpt_dir is not None:
|
||||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||||
print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
|
print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||||
|
@ -207,7 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
|
||||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
device = map_location or shared.weight_load_location or devices.get_optimal_device()
|
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||||
|
@ -354,9 +349,6 @@ def repair_config(sd_config):
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||||
|
|
||||||
|
|
||||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
|
||||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
@ -377,7 +369,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
||||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||||
clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
|
|
||||||
|
|
||||||
timer.record("find config")
|
timer.record("find config")
|
||||||
|
|
||||||
|
@ -390,7 +381,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
||||||
|
|
||||||
sd_model = None
|
sd_model = None
|
||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
with sd_disable_initialization.DisableInitialization():
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,11 +1,53 @@
|
||||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
from collections import namedtuple, deque
|
||||||
|
import numpy as np
|
||||||
|
from math import floor
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
import inspect
|
||||||
|
import k_diffusion.sampling
|
||||||
|
import torchsde._brownian.brownian_interval
|
||||||
|
import ldm.models.diffusion.ddim
|
||||||
|
import ldm.models.diffusion.plms
|
||||||
|
from modules import prompt_parser, devices, processing, images, sd_vae_approx
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
from modules.shared import opts, cmd_opts, state
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
import modules.shared as shared
|
||||||
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||||
|
|
||||||
|
|
||||||
|
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
|
|
||||||
|
samplers_k_diffusion = [
|
||||||
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||||
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
|
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||||
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||||
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||||
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||||
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||||
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||||
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||||
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||||
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||||
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||||
|
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||||
|
]
|
||||||
|
|
||||||
|
samplers_data_k_diffusion = [
|
||||||
|
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||||
|
for label, funcname, aliases, options in samplers_k_diffusion
|
||||||
|
if hasattr(k_diffusion.sampling, funcname)
|
||||||
|
]
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*samplers_data_k_diffusion,
|
||||||
*sd_samplers_compvis.samplers_data_compvis,
|
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||||
|
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||||
]
|
]
|
||||||
all_samplers_map = {x.name: x for x in all_samplers}
|
all_samplers_map = {x.name: x for x in all_samplers}
|
||||||
|
|
||||||
|
@ -31,8 +73,8 @@ def create_sampler(name, model):
|
||||||
def set_samplers():
|
def set_samplers():
|
||||||
global samplers, samplers_for_img2img
|
global samplers, samplers_for_img2img
|
||||||
|
|
||||||
hidden = set(shared.opts.hide_samplers)
|
hidden = set(opts.hide_samplers)
|
||||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
|
hidden_img2img = set(opts.hide_samplers + ['PLMS'])
|
||||||
|
|
||||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||||
|
@ -45,3 +87,466 @@ def set_samplers():
|
||||||
|
|
||||||
|
|
||||||
set_samplers()
|
set_samplers()
|
||||||
|
|
||||||
|
sampler_extra_params = {
|
||||||
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
|
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def setup_img2img_steps(p, steps=None):
|
||||||
|
if opts.img2img_fix_steps or steps is not None:
|
||||||
|
requested_steps = (steps or p.steps)
|
||||||
|
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||||
|
t_enc = requested_steps - 1
|
||||||
|
else:
|
||||||
|
steps = p.steps
|
||||||
|
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||||
|
|
||||||
|
return steps, t_enc
|
||||||
|
|
||||||
|
|
||||||
|
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||||
|
|
||||||
|
|
||||||
|
def single_sample_to_image(sample, approximation=None):
|
||||||
|
if approximation is None:
|
||||||
|
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||||
|
|
||||||
|
if approximation == 2:
|
||||||
|
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||||
|
elif approximation == 1:
|
||||||
|
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||||
|
else:
|
||||||
|
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||||
|
|
||||||
|
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
return Image.fromarray(x_sample)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_to_image(samples, index=0, approximation=None):
|
||||||
|
return single_sample_to_image(samples[index], approximation)
|
||||||
|
|
||||||
|
|
||||||
|
def samples_to_image_grid(samples, approximation=None):
|
||||||
|
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||||
|
|
||||||
|
|
||||||
|
def store_latent(decoded):
|
||||||
|
state.current_latent = decoded
|
||||||
|
|
||||||
|
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||||
|
if not shared.parallel_processing_allowed:
|
||||||
|
shared.state.assign_current_image(sample_to_image(decoded))
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptedException(BaseException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VanillaStableDiffusionSampler:
|
||||||
|
def __init__(self, constructor, sd_model):
|
||||||
|
self.sampler = constructor(sd_model)
|
||||||
|
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||||
|
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||||
|
self.mask = None
|
||||||
|
self.nmask = None
|
||||||
|
self.init_latent = None
|
||||||
|
self.sampler_noises = None
|
||||||
|
self.step = 0
|
||||||
|
self.stop_at = None
|
||||||
|
self.eta = None
|
||||||
|
self.default_eta = 0.0
|
||||||
|
self.config = None
|
||||||
|
self.last_latent = None
|
||||||
|
|
||||||
|
self.conditioning_key = sd_model.model.conditioning_key
|
||||||
|
|
||||||
|
def number_of_needed_noises(self, p):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
if self.stop_at is not None and self.step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||||
|
image_conditioning = None
|
||||||
|
if isinstance(cond, dict):
|
||||||
|
image_conditioning = cond["c_concat"][0]
|
||||||
|
cond = cond["c_crossattn"][0]
|
||||||
|
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||||
|
|
||||||
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
|
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||||
|
cond = tensor
|
||||||
|
|
||||||
|
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||||
|
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||||
|
# not 100% correct but should work well enough
|
||||||
|
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||||
|
last_vector = unconditional_conditioning[:, -1:]
|
||||||
|
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||||
|
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||||
|
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||||
|
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||||
|
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||||
|
|
||||||
|
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||||
|
# Note that they need to be lists because it just concatenates them later.
|
||||||
|
if image_conditioning is not None:
|
||||||
|
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||||
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||||
|
else:
|
||||||
|
self.last_latent = res[1]
|
||||||
|
|
||||||
|
store_latent(self.last_latent)
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
state.sampling_step = self.step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def initialize(self, p):
|
||||||
|
self.eta = p.eta if p.eta is not None else opts.eta_ddim
|
||||||
|
|
||||||
|
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||||
|
if hasattr(self.sampler, fieldname):
|
||||||
|
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||||
|
|
||||||
|
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
|
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
|
||||||
|
def adjust_steps_if_invalid(self, p, num_steps):
|
||||||
|
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||||
|
valid_step = 999 / (1000 // num_steps)
|
||||||
|
if valid_step == floor(valid_step):
|
||||||
|
return int(valid_step) + 1
|
||||||
|
|
||||||
|
return num_steps
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
steps = self.adjust_steps_if_invalid(p, steps)
|
||||||
|
self.initialize(p)
|
||||||
|
|
||||||
|
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||||
|
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||||
|
|
||||||
|
self.init_latent = x
|
||||||
|
self.last_latent = x
|
||||||
|
self.step = 0
|
||||||
|
|
||||||
|
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||||
|
if image_conditioning is not None:
|
||||||
|
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||||
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
|
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
self.initialize(p)
|
||||||
|
|
||||||
|
self.init_latent = None
|
||||||
|
self.last_latent = x
|
||||||
|
self.step = 0
|
||||||
|
|
||||||
|
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||||
|
|
||||||
|
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||||
|
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||||
|
if image_conditioning is not None:
|
||||||
|
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||||
|
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||||
|
|
||||||
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
|
|
||||||
|
return samples_ddim
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiser(torch.nn.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_model = model
|
||||||
|
self.mask = None
|
||||||
|
self.nmask = None
|
||||||
|
self.init_latent = None
|
||||||
|
self.step = 0
|
||||||
|
|
||||||
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||||
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
|
denoised = torch.clone(denoised_uncond)
|
||||||
|
|
||||||
|
for i, conds in enumerate(conds_list):
|
||||||
|
for cond_index, weight in conds:
|
||||||
|
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
|
batch_size = len(conds_list)
|
||||||
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||||
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||||
|
|
||||||
|
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||||
|
cfg_denoiser_callback(denoiser_params)
|
||||||
|
x_in = denoiser_params.x
|
||||||
|
image_cond_in = denoiser_params.image_cond
|
||||||
|
sigma_in = denoiser_params.sigma
|
||||||
|
|
||||||
|
if tensor.shape[1] == uncond.shape[1]:
|
||||||
|
cond_in = torch.cat([tensor, uncond])
|
||||||
|
|
||||||
|
if shared.batch_cond_uncond:
|
||||||
|
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||||
|
else:
|
||||||
|
x_out = torch.zeros_like(x_in)
|
||||||
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||||
|
a = batch_offset
|
||||||
|
b = a + batch_size
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||||
|
else:
|
||||||
|
x_out = torch.zeros_like(x_in)
|
||||||
|
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||||
|
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||||
|
a = batch_offset
|
||||||
|
b = min(a + batch_size, tensor.shape[0])
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||||
|
|
||||||
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||||
|
|
||||||
|
devices.test_for_nans(x_out, "unet")
|
||||||
|
|
||||||
|
if opts.live_preview_content == "Prompt":
|
||||||
|
store_latent(x_out[0:uncond.shape[0]])
|
||||||
|
elif opts.live_preview_content == "Negative prompt":
|
||||||
|
store_latent(x_out[-uncond.shape[0]:])
|
||||||
|
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
|
||||||
|
class TorchHijack:
|
||||||
|
def __init__(self, sampler_noises):
|
||||||
|
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||||
|
# implementation.
|
||||||
|
self.sampler_noises = deque(sampler_noises)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item == 'randn_like':
|
||||||
|
return self.randn_like
|
||||||
|
|
||||||
|
if hasattr(torch, item):
|
||||||
|
return getattr(torch, item)
|
||||||
|
|
||||||
|
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||||
|
|
||||||
|
def randn_like(self, x):
|
||||||
|
if self.sampler_noises:
|
||||||
|
noise = self.sampler_noises.popleft()
|
||||||
|
if noise.shape == x.shape:
|
||||||
|
return noise
|
||||||
|
|
||||||
|
if x.device.type == 'mps':
|
||||||
|
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||||
|
else:
|
||||||
|
return torch.randn_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
# MPS fix for randn in torchsde
|
||||||
|
def torchsde_randn(size, dtype, device, seed):
|
||||||
|
if device.type == 'mps':
|
||||||
|
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||||
|
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device).manual_seed(int(seed))
|
||||||
|
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
||||||
|
|
||||||
|
|
||||||
|
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||||
|
|
||||||
|
|
||||||
|
class KDiffusionSampler:
|
||||||
|
def __init__(self, funcname, sd_model):
|
||||||
|
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||||
|
|
||||||
|
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||||
|
self.funcname = funcname
|
||||||
|
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||||
|
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||||
|
self.sampler_noises = None
|
||||||
|
self.stop_at = None
|
||||||
|
self.eta = None
|
||||||
|
self.default_eta = 1.0
|
||||||
|
self.config = None
|
||||||
|
self.last_latent = None
|
||||||
|
|
||||||
|
self.conditioning_key = sd_model.model.conditioning_key
|
||||||
|
|
||||||
|
def callback_state(self, d):
|
||||||
|
step = d['i']
|
||||||
|
latent = d["denoised"]
|
||||||
|
if opts.live_preview_content == "Combined":
|
||||||
|
store_latent(latent)
|
||||||
|
self.last_latent = latent
|
||||||
|
|
||||||
|
if self.stop_at is not None and step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
state.sampling_step = step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
|
def number_of_needed_noises(self, p):
|
||||||
|
return p.steps
|
||||||
|
|
||||||
|
def initialize(self, p):
|
||||||
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
self.model_wrap_cfg.step = 0
|
||||||
|
self.eta = p.eta or opts.eta_ancestral
|
||||||
|
|
||||||
|
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||||
|
|
||||||
|
extra_params_kwargs = {}
|
||||||
|
for param_name in self.extra_params:
|
||||||
|
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||||
|
|
||||||
|
if 'eta' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['eta'] = self.eta
|
||||||
|
|
||||||
|
return extra_params_kwargs
|
||||||
|
|
||||||
|
def get_sigmas(self, p, steps):
|
||||||
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
|
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||||
|
discard_next_to_last_sigma = True
|
||||||
|
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||||
|
|
||||||
|
steps += 1 if discard_next_to_last_sigma else 0
|
||||||
|
|
||||||
|
if p.sampler_noise_scheduler_override:
|
||||||
|
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||||
|
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||||
|
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||||
|
|
||||||
|
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||||
|
else:
|
||||||
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
|
|
||||||
|
if discard_next_to_last_sigma:
|
||||||
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||||
|
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
|
sigmas = self.get_sigmas(p, steps)
|
||||||
|
|
||||||
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
|
xi = x + noise * sigma_sched[0]
|
||||||
|
|
||||||
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||||
|
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
|
||||||
|
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
||||||
|
if 'sigma_max' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
||||||
|
if 'n' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
||||||
|
if 'sigma_sched' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['sigma_sched'] = sigma_sched
|
||||||
|
if 'sigmas' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['sigmas'] = sigma_sched
|
||||||
|
|
||||||
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
self.last_latent = x
|
||||||
|
|
||||||
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale
|
||||||
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||||
|
steps = steps or p.steps
|
||||||
|
|
||||||
|
sigmas = self.get_sigmas(p, steps)
|
||||||
|
|
||||||
|
x = x * sigmas[0]
|
||||||
|
|
||||||
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||||
|
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||||
|
if 'n' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['n'] = steps
|
||||||
|
else:
|
||||||
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
|
|
||||||
|
self.last_latent = x
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale
|
||||||
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
|
@ -1,62 +0,0 @@
|
||||||
from collections import namedtuple
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from modules import devices, processing, images, sd_vae_approx
|
|
||||||
|
|
||||||
from modules.shared import opts, state
|
|
||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
|
||||||
|
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
|
||||||
if opts.img2img_fix_steps or steps is not None:
|
|
||||||
requested_steps = (steps or p.steps)
|
|
||||||
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
|
||||||
t_enc = requested_steps - 1
|
|
||||||
else:
|
|
||||||
steps = p.steps
|
|
||||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
|
||||||
|
|
||||||
return steps, t_enc
|
|
||||||
|
|
||||||
|
|
||||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
|
||||||
|
|
||||||
|
|
||||||
def single_sample_to_image(sample, approximation=None):
|
|
||||||
if approximation is None:
|
|
||||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
|
||||||
|
|
||||||
if approximation == 2:
|
|
||||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
|
||||||
elif approximation == 1:
|
|
||||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
|
||||||
else:
|
|
||||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
|
||||||
|
|
||||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
|
||||||
x_sample = x_sample.astype(np.uint8)
|
|
||||||
return Image.fromarray(x_sample)
|
|
||||||
|
|
||||||
|
|
||||||
def sample_to_image(samples, index=0, approximation=None):
|
|
||||||
return single_sample_to_image(samples[index], approximation)
|
|
||||||
|
|
||||||
|
|
||||||
def samples_to_image_grid(samples, approximation=None):
|
|
||||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
|
||||||
|
|
||||||
|
|
||||||
def store_latent(decoded):
|
|
||||||
state.current_latent = decoded
|
|
||||||
|
|
||||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
|
||||||
if not shared.parallel_processing_allowed:
|
|
||||||
shared.state.assign_current_image(sample_to_image(decoded))
|
|
||||||
|
|
||||||
|
|
||||||
class InterruptedException(BaseException):
|
|
||||||
pass
|
|
|
@ -1,160 +0,0 @@
|
||||||
import math
|
|
||||||
import ldm.models.diffusion.ddim
|
|
||||||
import ldm.models.diffusion.plms
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from modules.shared import state
|
|
||||||
from modules import sd_samplers_common, prompt_parser, shared
|
|
||||||
|
|
||||||
|
|
||||||
samplers_data_compvis = [
|
|
||||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
|
||||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class VanillaStableDiffusionSampler:
|
|
||||||
def __init__(self, constructor, sd_model):
|
|
||||||
self.sampler = constructor(sd_model)
|
|
||||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
|
||||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
|
||||||
self.mask = None
|
|
||||||
self.nmask = None
|
|
||||||
self.init_latent = None
|
|
||||||
self.sampler_noises = None
|
|
||||||
self.step = 0
|
|
||||||
self.stop_at = None
|
|
||||||
self.eta = None
|
|
||||||
self.config = None
|
|
||||||
self.last_latent = None
|
|
||||||
|
|
||||||
self.conditioning_key = sd_model.model.conditioning_key
|
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
|
||||||
state.sampling_steps = steps
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except sd_samplers_common.InterruptedException:
|
|
||||||
return self.last_latent
|
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
if self.stop_at is not None and self.step > self.stop_at:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
|
||||||
image_conditioning = None
|
|
||||||
if isinstance(cond, dict):
|
|
||||||
image_conditioning = cond["c_concat"][0]
|
|
||||||
cond = cond["c_crossattn"][0]
|
|
||||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
|
||||||
|
|
||||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
|
||||||
cond = tensor
|
|
||||||
|
|
||||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
|
||||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
|
||||||
# not 100% correct but should work well enough
|
|
||||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
|
||||||
last_vector = unconditional_conditioning[:, -1:]
|
|
||||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
|
||||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
|
||||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
|
||||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
|
||||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
|
||||||
|
|
||||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
|
||||||
# Note that they need to be lists because it just concatenates them later.
|
|
||||||
if image_conditioning is not None:
|
|
||||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
|
||||||
else:
|
|
||||||
self.last_latent = res[1]
|
|
||||||
|
|
||||||
sd_samplers_common.store_latent(self.last_latent)
|
|
||||||
|
|
||||||
self.step += 1
|
|
||||||
state.sampling_step = self.step
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def initialize(self, p):
|
|
||||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
|
||||||
if self.eta != 0.0:
|
|
||||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
|
||||||
|
|
||||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
|
||||||
if hasattr(self.sampler, fieldname):
|
|
||||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
|
||||||
|
|
||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
||||||
|
|
||||||
def adjust_steps_if_invalid(self, p, num_steps):
|
|
||||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
|
||||||
valid_step = 999 / (1000 // num_steps)
|
|
||||||
if valid_step == math.floor(valid_step):
|
|
||||||
return int(valid_step) + 1
|
|
||||||
|
|
||||||
return num_steps
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
|
||||||
steps = self.adjust_steps_if_invalid(p, steps)
|
|
||||||
self.initialize(p)
|
|
||||||
|
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
|
||||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
|
||||||
|
|
||||||
self.init_latent = x
|
|
||||||
self.last_latent = x
|
|
||||||
self.step = 0
|
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
|
||||||
if image_conditioning is not None:
|
|
||||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
|
||||||
self.initialize(p)
|
|
||||||
|
|
||||||
self.init_latent = None
|
|
||||||
self.last_latent = x
|
|
||||||
self.step = 0
|
|
||||||
|
|
||||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
|
||||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
|
||||||
if image_conditioning is not None:
|
|
||||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
|
||||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
|
||||||
|
|
||||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
|
||||||
|
|
||||||
return samples_ddim
|
|
|
@ -1,331 +0,0 @@
|
||||||
from collections import deque
|
|
||||||
import torch
|
|
||||||
import inspect
|
|
||||||
import einops
|
|
||||||
import k_diffusion.sampling
|
|
||||||
from modules import prompt_parser, devices, sd_samplers_common
|
|
||||||
|
|
||||||
from modules.shared import opts, state
|
|
||||||
import modules.shared as shared
|
|
||||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
|
||||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
|
||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
|
||||||
]
|
|
||||||
|
|
||||||
samplers_data_k_diffusion = [
|
|
||||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
|
||||||
for label, funcname, aliases, options in samplers_k_diffusion
|
|
||||||
if hasattr(k_diffusion.sampling, funcname)
|
|
||||||
]
|
|
||||||
|
|
||||||
sampler_extra_params = {
|
|
||||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
|
||||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
|
||||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
|
||||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
|
||||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
|
||||||
negative prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
|
||||||
self.mask = None
|
|
||||||
self.nmask = None
|
|
||||||
self.init_latent = None
|
|
||||||
self.step = 0
|
|
||||||
self.image_cfg_scale = None
|
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
|
||||||
denoised = torch.clone(denoised_uncond)
|
|
||||||
|
|
||||||
for i, conds in enumerate(conds_list):
|
|
||||||
for cond_index, weight in conds:
|
|
||||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
|
||||||
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
|
||||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
|
||||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
|
||||||
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
|
||||||
# so is_edit_model is set to False to support AND composition.
|
|
||||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
|
||||||
|
|
||||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
|
||||||
|
|
||||||
batch_size = len(conds_list)
|
|
||||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
|
||||||
|
|
||||||
if not is_edit_model:
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
|
||||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
|
||||||
else:
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
|
||||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
|
||||||
|
|
||||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
|
||||||
cfg_denoiser_callback(denoiser_params)
|
|
||||||
x_in = denoiser_params.x
|
|
||||||
image_cond_in = denoiser_params.image_cond
|
|
||||||
sigma_in = denoiser_params.sigma
|
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1]:
|
|
||||||
if not is_edit_model:
|
|
||||||
cond_in = torch.cat([tensor, uncond])
|
|
||||||
else:
|
|
||||||
cond_in = torch.cat([tensor, uncond, uncond])
|
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
|
||||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
|
||||||
else:
|
|
||||||
x_out = torch.zeros_like(x_in)
|
|
||||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
|
||||||
a = batch_offset
|
|
||||||
b = a + batch_size
|
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
|
||||||
else:
|
|
||||||
x_out = torch.zeros_like(x_in)
|
|
||||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
|
||||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
|
||||||
a = batch_offset
|
|
||||||
b = min(a + batch_size, tensor.shape[0])
|
|
||||||
|
|
||||||
if not is_edit_model:
|
|
||||||
c_crossattn = [tensor[a:b]]
|
|
||||||
else:
|
|
||||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
|
||||||
|
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
|
|
||||||
|
|
||||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
|
||||||
|
|
||||||
devices.test_for_nans(x_out, "unet")
|
|
||||||
|
|
||||||
if opts.live_preview_content == "Prompt":
|
|
||||||
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
|
||||||
elif opts.live_preview_content == "Negative prompt":
|
|
||||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
|
||||||
|
|
||||||
if not is_edit_model:
|
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
|
||||||
else:
|
|
||||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
|
||||||
|
|
||||||
self.step += 1
|
|
||||||
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
|
|
||||||
class TorchHijack:
|
|
||||||
def __init__(self, sampler_noises):
|
|
||||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
|
||||||
# implementation.
|
|
||||||
self.sampler_noises = deque(sampler_noises)
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
if item == 'randn_like':
|
|
||||||
return self.randn_like
|
|
||||||
|
|
||||||
if hasattr(torch, item):
|
|
||||||
return getattr(torch, item)
|
|
||||||
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
|
||||||
|
|
||||||
def randn_like(self, x):
|
|
||||||
if self.sampler_noises:
|
|
||||||
noise = self.sampler_noises.popleft()
|
|
||||||
if noise.shape == x.shape:
|
|
||||||
return noise
|
|
||||||
|
|
||||||
if x.device.type == 'mps':
|
|
||||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
|
||||||
else:
|
|
||||||
return torch.randn_like(x)
|
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
|
||||||
def __init__(self, funcname, sd_model):
|
|
||||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
|
||||||
|
|
||||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
|
||||||
self.funcname = funcname
|
|
||||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
|
||||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
|
||||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
|
||||||
self.sampler_noises = None
|
|
||||||
self.stop_at = None
|
|
||||||
self.eta = None
|
|
||||||
self.config = None
|
|
||||||
self.last_latent = None
|
|
||||||
|
|
||||||
self.conditioning_key = sd_model.model.conditioning_key
|
|
||||||
|
|
||||||
def callback_state(self, d):
|
|
||||||
step = d['i']
|
|
||||||
latent = d["denoised"]
|
|
||||||
if opts.live_preview_content == "Combined":
|
|
||||||
sd_samplers_common.store_latent(latent)
|
|
||||||
self.last_latent = latent
|
|
||||||
|
|
||||||
if self.stop_at is not None and step > self.stop_at:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
state.sampling_step = step
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
|
||||||
state.sampling_steps = steps
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except sd_samplers_common.InterruptedException:
|
|
||||||
return self.last_latent
|
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
|
||||||
return p.steps
|
|
||||||
|
|
||||||
def initialize(self, p):
|
|
||||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
|
||||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
||||||
self.model_wrap_cfg.step = 0
|
|
||||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
|
||||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
|
||||||
|
|
||||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
|
||||||
|
|
||||||
extra_params_kwargs = {}
|
|
||||||
for param_name in self.extra_params:
|
|
||||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
|
||||||
|
|
||||||
if 'eta' in inspect.signature(self.func).parameters:
|
|
||||||
if self.eta != 1.0:
|
|
||||||
p.extra_generation_params["Eta"] = self.eta
|
|
||||||
|
|
||||||
extra_params_kwargs['eta'] = self.eta
|
|
||||||
|
|
||||||
return extra_params_kwargs
|
|
||||||
|
|
||||||
def get_sigmas(self, p, steps):
|
|
||||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
|
||||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
|
||||||
discard_next_to_last_sigma = True
|
|
||||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
|
||||||
|
|
||||||
steps += 1 if discard_next_to_last_sigma else 0
|
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
|
||||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
|
||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
|
||||||
|
|
||||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
|
||||||
else:
|
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
|
||||||
|
|
||||||
if discard_next_to_last_sigma:
|
|
||||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
|
||||||
|
|
||||||
return sigmas
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
|
||||||
|
|
||||||
sigmas = self.get_sigmas(p, steps)
|
|
||||||
|
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
|
||||||
xi = x + noise * sigma_sched[0]
|
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
|
||||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
|
||||||
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
|
|
||||||
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
|
||||||
if 'sigma_max' in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
|
||||||
if 'n' in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
|
||||||
if 'sigma_sched' in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs['sigma_sched'] = sigma_sched
|
|
||||||
if 'sigmas' in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs['sigmas'] = sigma_sched
|
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
|
||||||
self.last_latent = x
|
|
||||||
extra_args={
|
|
||||||
'cond': conditioning,
|
|
||||||
'image_cond': image_conditioning,
|
|
||||||
'uncond': unconditional_conditioning,
|
|
||||||
'cond_scale': p.cfg_scale,
|
|
||||||
}
|
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
|
||||||
steps = steps or p.steps
|
|
||||||
|
|
||||||
sigmas = self.get_sigmas(p, steps)
|
|
||||||
|
|
||||||
x = x * sigmas[0]
|
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
|
||||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
|
||||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
|
||||||
if 'n' in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs['n'] = steps
|
|
||||||
else:
|
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
|
||||||
|
|
||||||
self.last_latent = x
|
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
|
||||||
'cond': conditioning,
|
|
||||||
'image_cond': image_conditioning,
|
|
||||||
'uncond': unconditional_conditioning,
|
|
||||||
'cond_scale': p.cfg_scale
|
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
|
@ -105,8 +105,6 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
|
||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
|
||||||
|
|
||||||
|
|
||||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||||
|
@ -129,13 +127,12 @@ restricted_opts = {
|
||||||
ui_reorder_categories = [
|
ui_reorder_categories = [
|
||||||
"inpaint",
|
"inpaint",
|
||||||
"sampler",
|
"sampler",
|
||||||
"checkboxes",
|
|
||||||
"hires_fix",
|
|
||||||
"dimensions",
|
"dimensions",
|
||||||
"cfg",
|
"cfg",
|
||||||
"seed",
|
"seed",
|
||||||
|
"checkboxes",
|
||||||
|
"hires_fix",
|
||||||
"batch",
|
"batch",
|
||||||
"override_settings",
|
|
||||||
"scripts",
|
"scripts",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -327,7 +324,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||||
|
|
||||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||||
|
@ -349,10 +346,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
||||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
|
||||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
|
"directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
|
||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -443,7 +440,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
|
@ -608,37 +605,11 @@ class Options:
|
||||||
|
|
||||||
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
||||||
|
|
||||||
def cast_value(self, key, value):
|
|
||||||
"""casts an arbitrary to the same type as this setting's value with key
|
|
||||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
default_value = self.data_labels[key].default
|
|
||||||
if default_value is None:
|
|
||||||
default_value = getattr(self, key, None)
|
|
||||||
if default_value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
expected_type = type(default_value)
|
|
||||||
if expected_type == bool and value == "False":
|
|
||||||
value = False
|
|
||||||
else:
|
|
||||||
value = expected_type(value)
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
opts = Options()
|
opts = Options()
|
||||||
if os.path.exists(config_filename):
|
if os.path.exists(config_filename):
|
||||||
opts.load(config_filename)
|
opts.load(config_filename)
|
||||||
|
|
||||||
settings_components = None
|
|
||||||
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
|
|
||||||
|
|
||||||
latent_upscale_default_mode = "Latent"
|
latent_upscale_default_mode = "Latent"
|
||||||
latent_upscale_modes = {
|
latent_upscale_modes = {
|
||||||
"Latent": {"mode": "bilinear", "antialias": False},
|
"Latent": {"mode": "bilinear", "antialias": False},
|
||||||
|
|
|
@ -112,7 +112,6 @@ class EmbeddingDatabase:
|
||||||
self.skipped_embeddings = {}
|
self.skipped_embeddings = {}
|
||||||
self.expected_shape = -1
|
self.expected_shape = -1
|
||||||
self.embedding_dirs = {}
|
self.embedding_dirs = {}
|
||||||
self.previously_displayed_embeddings = ()
|
|
||||||
|
|
||||||
def add_embedding_dir(self, path):
|
def add_embedding_dir(self, path):
|
||||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||||
|
@ -229,12 +228,9 @@ class EmbeddingDatabase:
|
||||||
self.load_from_dir(embdir)
|
self.load_from_dir(embdir)
|
||||||
embdir.update()
|
embdir.update()
|
||||||
|
|
||||||
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||||
if self.previously_displayed_embeddings != displayed_embeddings:
|
if len(self.skipped_embeddings) > 0:
|
||||||
self.previously_displayed_embeddings = displayed_embeddings
|
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
|
||||||
if len(self.skipped_embeddings) > 0:
|
|
||||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
|
||||||
|
|
||||||
def find_embedding_at_position(self, tokens, offset):
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
token = tokens[offset]
|
token = tokens[offset]
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
from modules import sd_samplers
|
from modules import sd_samplers
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
|
||||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||||
StableDiffusionProcessingImg2Img, process_images
|
StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
|
@ -9,9 +8,7 @@ import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
|
|
||||||
|
|
||||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
|
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
|
||||||
|
|
||||||
p = StableDiffusionProcessingTxt2Img(
|
p = StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||||
|
@ -41,7 +38,6 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||||
hr_second_pass_steps=hr_second_pass_steps,
|
hr_second_pass_steps=hr_second_pass_steps,
|
||||||
hr_resize_x=hr_resize_x,
|
hr_resize_x=hr_resize_x,
|
||||||
hr_resize_y=hr_resize_y,
|
hr_resize_y=hr_resize_y,
|
||||||
override_settings=override_settings,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
p.scripts = modules.scripts.scripts_txt2img
|
p.scripts = modules.scripts.scripts_txt2img
|
||||||
|
|
|
@ -380,7 +380,6 @@ def apply_setting(key, value):
|
||||||
opts.save(shared.config_filename)
|
opts.save(shared.config_filename)
|
||||||
return getattr(opts, key)
|
return getattr(opts, key)
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||||
def refresh():
|
def refresh():
|
||||||
refresh_method()
|
refresh_method()
|
||||||
|
@ -434,18 +433,6 @@ def get_value_for_setting(key):
|
||||||
return gr.update(value=value, **args)
|
return gr.update(value=value, **args)
|
||||||
|
|
||||||
|
|
||||||
def create_override_settings_dropdown(tabname, row):
|
|
||||||
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
|
|
||||||
|
|
||||||
dropdown.change(
|
|
||||||
fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
|
|
||||||
inputs=[dropdown],
|
|
||||||
outputs=[dropdown],
|
|
||||||
)
|
|
||||||
|
|
||||||
return dropdown
|
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
@ -479,8 +466,8 @@ def create_ui():
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||||
|
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
|
||||||
if opts.dimensions_and_batch_together:
|
if opts.dimensions_and_batch_together:
|
||||||
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||||
with gr.Column(elem_id="txt2img_column_batch"):
|
with gr.Column(elem_id="txt2img_column_batch"):
|
||||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
||||||
|
@ -516,10 +503,6 @@ def create_ui():
|
||||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
||||||
|
|
||||||
elif category == "override_settings":
|
|
||||||
with FormRow(elem_id="txt2img_override_settings_row") as row:
|
|
||||||
override_settings = create_override_settings_dropdown('txt2img', row)
|
|
||||||
|
|
||||||
elif category == "scripts":
|
elif category == "scripts":
|
||||||
with FormGroup(elem_id="txt2img_script_container"):
|
with FormGroup(elem_id="txt2img_script_container"):
|
||||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
|
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
|
||||||
|
@ -541,6 +524,7 @@ def create_ui():
|
||||||
)
|
)
|
||||||
|
|
||||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||||
|
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
|
||||||
|
|
||||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||||
|
@ -571,7 +555,6 @@ def create_ui():
|
||||||
hr_second_pass_steps,
|
hr_second_pass_steps,
|
||||||
hr_resize_x,
|
hr_resize_x,
|
||||||
hr_resize_y,
|
hr_resize_y,
|
||||||
override_settings,
|
|
||||||
] + custom_inputs,
|
] + custom_inputs,
|
||||||
|
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -632,9 +615,6 @@ def create_ui():
|
||||||
*modules.scripts.scripts_txt2img.infotext_fields
|
*modules.scripts.scripts_txt2img.infotext_fields
|
||||||
]
|
]
|
||||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
||||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
|
||||||
paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings,
|
|
||||||
))
|
|
||||||
|
|
||||||
txt2img_preview_params = [
|
txt2img_preview_params = [
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
|
@ -757,17 +737,15 @@ def create_ui():
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||||
|
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
|
||||||
if opts.dimensions_and_batch_together:
|
if opts.dimensions_and_batch_together:
|
||||||
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||||
with gr.Column(elem_id="img2img_column_batch"):
|
with gr.Column(elem_id="img2img_column_batch"):
|
||||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
||||||
|
|
||||||
elif category == "cfg":
|
elif category == "cfg":
|
||||||
with FormGroup():
|
with FormGroup():
|
||||||
with FormRow():
|
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
|
||||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
|
||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
||||||
|
|
||||||
elif category == "seed":
|
elif category == "seed":
|
||||||
|
@ -784,10 +762,6 @@ def create_ui():
|
||||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
||||||
|
|
||||||
elif category == "override_settings":
|
|
||||||
with FormRow(elem_id="img2img_override_settings_row") as row:
|
|
||||||
override_settings = create_override_settings_dropdown('img2img', row)
|
|
||||||
|
|
||||||
elif category == "scripts":
|
elif category == "scripts":
|
||||||
with FormGroup(elem_id="img2img_script_container"):
|
with FormGroup(elem_id="img2img_script_container"):
|
||||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
|
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
|
||||||
|
@ -822,6 +796,7 @@ def create_ui():
|
||||||
)
|
)
|
||||||
|
|
||||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||||
|
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
|
||||||
|
|
||||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||||
|
@ -863,7 +838,6 @@ def create_ui():
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
image_cfg_scale,
|
|
||||||
denoising_strength,
|
denoising_strength,
|
||||||
seed,
|
seed,
|
||||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
||||||
|
@ -875,8 +849,7 @@ def create_ui():
|
||||||
inpainting_mask_invert,
|
inpainting_mask_invert,
|
||||||
img2img_batch_input_dir,
|
img2img_batch_input_dir,
|
||||||
img2img_batch_output_dir,
|
img2img_batch_output_dir,
|
||||||
img2img_batch_inpaint_mask_dir,
|
img2img_batch_inpaint_mask_dir
|
||||||
override_settings,
|
|
||||||
] + custom_inputs,
|
] + custom_inputs,
|
||||||
outputs=[
|
outputs=[
|
||||||
img2img_gallery,
|
img2img_gallery,
|
||||||
|
@ -950,7 +923,6 @@ def create_ui():
|
||||||
(sampler_index, "Sampler"),
|
(sampler_index, "Sampler"),
|
||||||
(restore_faces, "Face restoration"),
|
(restore_faces, "Face restoration"),
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(image_cfg_scale, "Image CFG scale"),
|
|
||||||
(seed, "Seed"),
|
(seed, "Seed"),
|
||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
(height, "Size-2"),
|
(height, "Size-2"),
|
||||||
|
@ -965,9 +937,6 @@ def create_ui():
|
||||||
]
|
]
|
||||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
|
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
|
||||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
|
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
|
||||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
|
||||||
paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings,
|
|
||||||
))
|
|
||||||
|
|
||||||
modules.scripts.scripts_current = None
|
modules.scripts.scripts_current = None
|
||||||
|
|
||||||
|
@ -985,11 +954,7 @@ def create_ui():
|
||||||
html2 = gr.HTML()
|
html2 = gr.HTML()
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
|
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
|
||||||
|
parameters_copypaste.bind_buttons(buttons, image, generation_info)
|
||||||
for tabname, button in buttons.items():
|
|
||||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
|
||||||
paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
|
|
||||||
))
|
|
||||||
|
|
||||||
image.change(
|
image.change(
|
||||||
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||||
|
@ -1398,7 +1363,6 @@ def create_ui():
|
||||||
|
|
||||||
components = []
|
components = []
|
||||||
component_dict = {}
|
component_dict = {}
|
||||||
shared.settings_components = component_dict
|
|
||||||
|
|
||||||
script_callbacks.ui_settings_callback()
|
script_callbacks.ui_settings_callback()
|
||||||
opts.reorder()
|
opts.reorder()
|
||||||
|
@ -1565,7 +1529,8 @@ def create_ui():
|
||||||
component = create_setting_component(k, is_quicksettings=True)
|
component = create_setting_component(k, is_quicksettings=True)
|
||||||
component_dict[k] = component
|
component_dict[k] = component
|
||||||
|
|
||||||
parameters_copypaste.connect_paste_params_buttons()
|
parameters_copypaste.integrate_settings_paste_fields(component_dict)
|
||||||
|
parameters_copypaste.run_bind()
|
||||||
|
|
||||||
with gr.Tabs(elem_id="tabs") as tabs:
|
with gr.Tabs(elem_id="tabs") as tabs:
|
||||||
for interface, label, ifid in interfaces:
|
for interface, label, ifid in interfaces:
|
||||||
|
@ -1595,20 +1560,6 @@ def create_ui():
|
||||||
outputs=[component, text_settings],
|
outputs=[component, text_settings],
|
||||||
)
|
)
|
||||||
|
|
||||||
text_settings.change(
|
|
||||||
fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[image_cfg_scale],
|
|
||||||
)
|
|
||||||
|
|
||||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
|
||||||
button_set_checkpoint.click(
|
|
||||||
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
|
|
||||||
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
|
||||||
inputs=[component_dict['sd_model_checkpoint'], dummy_component],
|
|
||||||
outputs=[component_dict['sd_model_checkpoint'], text_settings],
|
|
||||||
)
|
|
||||||
|
|
||||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
||||||
|
|
||||||
def get_settings_values():
|
def get_settings_values():
|
||||||
|
|
|
@ -198,9 +198,5 @@ Requested path was: {f}
|
||||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||||
|
|
||||||
for paste_tabname, paste_button in buttons.items():
|
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
|
||||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
|
||||||
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
|
|
||||||
))
|
|
||||||
|
|
||||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
||||||
|
|
|
@ -1,7 +1,4 @@
|
||||||
import glob
|
|
||||||
import os.path
|
import os.path
|
||||||
import urllib.parse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
@ -11,32 +8,12 @@ import html
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
|
||||||
extra_pages = []
|
extra_pages = []
|
||||||
allowed_dirs = set()
|
|
||||||
|
|
||||||
|
|
||||||
def register_page(page):
|
def register_page(page):
|
||||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||||
|
|
||||||
extra_pages.append(page)
|
extra_pages.append(page)
|
||||||
allowed_dirs.clear()
|
|
||||||
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
|
||||||
|
|
||||||
|
|
||||||
def add_pages_to_demo(app):
|
|
||||||
def fetch_file(filename: str = ""):
|
|
||||||
from starlette.responses import FileResponse
|
|
||||||
|
|
||||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
|
||||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
|
||||||
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
|
||||||
if ext not in (".png", ".jpg"):
|
|
||||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
|
|
||||||
|
|
||||||
# would profit from returning 304
|
|
||||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
|
||||||
|
|
||||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPage:
|
class ExtraNetworksPage:
|
||||||
|
@ -49,44 +26,10 @@ class ExtraNetworksPage:
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def link_preview(self, filename):
|
|
||||||
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
|
|
||||||
|
|
||||||
def search_terms_from_path(self, filename, possible_directories=None):
|
|
||||||
abspath = os.path.abspath(filename)
|
|
||||||
|
|
||||||
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
|
||||||
parentdir = os.path.abspath(parentdir)
|
|
||||||
if abspath.startswith(parentdir):
|
|
||||||
return abspath[len(parentdir):].replace('\\', '/')
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def create_html(self, tabname):
|
def create_html(self, tabname):
|
||||||
view = shared.opts.extra_networks_default_view
|
view = shared.opts.extra_networks_default_view
|
||||||
items_html = ''
|
items_html = ''
|
||||||
|
|
||||||
subdirs = {}
|
|
||||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
|
||||||
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
|
|
||||||
if not os.path.isdir(x):
|
|
||||||
continue
|
|
||||||
|
|
||||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
|
||||||
while subdir.startswith("/"):
|
|
||||||
subdir = subdir[1:]
|
|
||||||
|
|
||||||
subdirs[subdir] = 1
|
|
||||||
|
|
||||||
if subdirs:
|
|
||||||
subdirs = {"": 1, **subdirs}
|
|
||||||
|
|
||||||
subdirs_html = "".join([f"""
|
|
||||||
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
|
||||||
{html.escape(subdir if subdir!="" else "all")}
|
|
||||||
</button>
|
|
||||||
""" for subdir in subdirs])
|
|
||||||
|
|
||||||
for item in self.list_items():
|
for item in self.list_items():
|
||||||
items_html += self.create_html_for_item(item, tabname)
|
items_html += self.create_html_for_item(item, tabname)
|
||||||
|
|
||||||
|
@ -95,9 +38,6 @@ class ExtraNetworksPage:
|
||||||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||||
|
|
||||||
res = f"""
|
res = f"""
|
||||||
<div id='{tabname}_{self.name}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
|
|
||||||
{subdirs_html}
|
|
||||||
</div>
|
|
||||||
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
|
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
|
||||||
{items_html}
|
{items_html}
|
||||||
</div>
|
</div>
|
||||||
|
@ -114,19 +54,14 @@ class ExtraNetworksPage:
|
||||||
def create_html_for_item(self, item, tabname):
|
def create_html_for_item(self, item, tabname):
|
||||||
preview = item.get("preview", None)
|
preview = item.get("preview", None)
|
||||||
|
|
||||||
onclick = item.get("onclick", None)
|
|
||||||
if onclick is None:
|
|
||||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||||
"prompt": item.get("prompt", None),
|
"prompt": item["prompt"],
|
||||||
"tabname": json.dumps(tabname),
|
"tabname": json.dumps(tabname),
|
||||||
"local_preview": json.dumps(item["local_preview"]),
|
"local_preview": json.dumps(item["local_preview"]),
|
||||||
"name": item["name"],
|
"name": item["name"],
|
||||||
"card_clicked": onclick,
|
"card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
|
||||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||||
"search_term": item.get("search_term", ""),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.card_page.format(**args)
|
return self.card_page.format(**args)
|
||||||
|
@ -208,7 +143,7 @@ def path_is_parent(parent_path, child_path):
|
||||||
parent_path = os.path.abspath(parent_path)
|
parent_path = os.path.abspath(parent_path)
|
||||||
child_path = os.path.abspath(child_path)
|
child_path = os.path.abspath(child_path)
|
||||||
|
|
||||||
return child_path.startswith(parent_path)
|
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
|
||||||
|
|
||||||
|
|
||||||
def setup_ui(ui, gallery):
|
def setup_ui(ui, gallery):
|
||||||
|
@ -238,8 +173,7 @@ def setup_ui(ui, gallery):
|
||||||
|
|
||||||
ui.button_save_preview.click(
|
ui.button_save_preview.click(
|
||||||
fn=save_preview,
|
fn=save_preview,
|
||||||
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
_js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
|
||||||
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
||||||
outputs=[*ui.pages]
|
outputs=[*ui.pages]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,39 +0,0 @@
|
||||||
import html
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import urllib.parse
|
|
||||||
|
|
||||||
from modules import shared, ui_extra_networks, sd_models
|
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__('Checkpoints')
|
|
||||||
|
|
||||||
def refresh(self):
|
|
||||||
shared.refresh_checkpoints()
|
|
||||||
|
|
||||||
def list_items(self):
|
|
||||||
checkpoint: sd_models.CheckpointInfo
|
|
||||||
for name, checkpoint in sd_models.checkpoints_list.items():
|
|
||||||
path, ext = os.path.splitext(checkpoint.filename)
|
|
||||||
previews = [path + ".png", path + ".preview.png"]
|
|
||||||
|
|
||||||
preview = None
|
|
||||||
for file in previews:
|
|
||||||
if os.path.isfile(file):
|
|
||||||
preview = self.link_preview(file)
|
|
||||||
break
|
|
||||||
|
|
||||||
yield {
|
|
||||||
"name": checkpoint.name_for_extra,
|
|
||||||
"filename": path,
|
|
||||||
"preview": preview,
|
|
||||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
|
||||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
|
|
||||||
"local_preview": path + ".png",
|
|
||||||
}
|
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
|
||||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
|
||||||
|
|
|
@ -19,14 +19,13 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||||
preview = None
|
preview = None
|
||||||
for file in previews:
|
for file in previews:
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
preview = self.link_preview(file)
|
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": path,
|
"filename": path,
|
||||||
"preview": preview,
|
"preview": preview,
|
||||||
"search_term": self.search_terms_from_path(path),
|
|
||||||
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||||
"local_preview": path + ".png",
|
"local_preview": path + ".png",
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,13 +19,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||||
|
|
||||||
preview = None
|
preview = None
|
||||||
if os.path.isfile(preview_file):
|
if os.path.isfile(preview_file):
|
||||||
preview = self.link_preview(preview_file)
|
preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": embedding.name,
|
"name": embedding.name,
|
||||||
"filename": embedding.filename,
|
"filename": embedding.filename,
|
||||||
"preview": preview,
|
"preview": preview,
|
||||||
"search_term": self.search_terms_from_path(embedding.filename),
|
|
||||||
"prompt": json.dumps(embedding.name),
|
"prompt": json.dumps(embedding.name),
|
||||||
"local_preview": path + ".preview.png",
|
"local_preview": path + ".preview.png",
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ from tqdm import trange
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
|
from modules import processing, shared, sd_samplers, prompt_parser
|
||||||
from modules.processing import Processed
|
from modules.processing import Processed
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
||||||
|
|
||||||
x = x + d * dt
|
x = x + d * dt
|
||||||
|
|
||||||
sd_samplers_common.store_latent(x)
|
sd_samplers.store_latent(x)
|
||||||
|
|
||||||
# This shouldn't be necessary, but solved some VRAM issues
|
# This shouldn't be necessary, but solved some VRAM issues
|
||||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||||
|
@ -104,7 +104,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
||||||
dt = sigmas[i] - sigmas[i - 1]
|
dt = sigmas[i] - sigmas[i - 1]
|
||||||
x = x + d * dt
|
x = x + d * dt
|
||||||
|
|
||||||
sd_samplers_common.store_latent(x)
|
sd_samplers.store_latent(x)
|
||||||
|
|
||||||
# This shouldn't be necessary, but solved some VRAM issues
|
# This shouldn't be necessary, but solved some VRAM issues
|
||||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||||
|
|
|
@ -44,34 +44,16 @@ class Script(scripts.Script):
|
||||||
def title(self):
|
def title(self):
|
||||||
return "Prompt matrix"
|
return "Prompt matrix"
|
||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
gr.HTML('<br />')
|
put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
|
||||||
with gr.Row():
|
different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
|
||||||
with gr.Column():
|
|
||||||
put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
|
|
||||||
different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
|
|
||||||
with gr.Column():
|
|
||||||
prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
|
|
||||||
variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
|
|
||||||
with gr.Column():
|
|
||||||
margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
|
||||||
|
|
||||||
return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
|
return [put_at_start, different_seeds]
|
||||||
|
|
||||||
def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
|
def run(self, p, put_at_start, different_seeds):
|
||||||
modules.processing.fix_seed(p)
|
modules.processing.fix_seed(p)
|
||||||
# Raise error if promp type is not positive or negative
|
|
||||||
if prompt_type not in ["positive", "negative"]:
|
|
||||||
raise ValueError(f"Unknown prompt type {prompt_type}")
|
|
||||||
# Raise error if variations delimiter is not comma or space
|
|
||||||
if variations_delimiter not in ["comma", "space"]:
|
|
||||||
raise ValueError(f"Unknown variations delimiter {variations_delimiter}")
|
|
||||||
|
|
||||||
prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
|
original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
||||||
original_prompt = prompt[0] if type(prompt) == list else prompt
|
|
||||||
positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
|
||||||
|
|
||||||
delimiter = ", " if variations_delimiter == "comma" else " "
|
|
||||||
|
|
||||||
all_prompts = []
|
all_prompts = []
|
||||||
prompt_matrix_parts = original_prompt.split("|")
|
prompt_matrix_parts = original_prompt.split("|")
|
||||||
|
@ -84,23 +66,20 @@ class Script(scripts.Script):
|
||||||
else:
|
else:
|
||||||
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
|
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
|
||||||
|
|
||||||
all_prompts.append(delimiter.join(selected_prompts))
|
all_prompts.append(", ".join(selected_prompts))
|
||||||
|
|
||||||
p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
|
p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
|
||||||
p.do_not_save_grid = True
|
p.do_not_save_grid = True
|
||||||
|
|
||||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
|
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
|
||||||
|
|
||||||
if prompt_type == "positive":
|
p.prompt = all_prompts
|
||||||
p.prompt = all_prompts
|
|
||||||
else:
|
|
||||||
p.negative_prompt = all_prompts
|
|
||||||
p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
|
p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
|
||||||
p.prompt_for_display = positive_prompt
|
p.prompt_for_display = original_prompt
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
|
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
|
||||||
grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts, margin_size)
|
grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
|
||||||
processed.images.insert(0, grid)
|
processed.images.insert(0, grid)
|
||||||
processed.index_of_first_image = 1
|
processed.index_of_first_image = 1
|
||||||
processed.infotexts.insert(0, processed.infotexts[0])
|
processed.infotexts.insert(0, processed.infotexts[0])
|
||||||
|
|
|
@ -205,7 +205,7 @@ axis_options = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
|
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed):
|
||||||
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
||||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||||
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
||||||
|
@ -286,24 +286,23 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
||||||
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
||||||
return Processed(p, [])
|
return Processed(p, [])
|
||||||
|
|
||||||
sub_grids = [None] * len(zs)
|
grids = [None] * len(zs)
|
||||||
for i in range(len(zs)):
|
for i in range(len(zs)):
|
||||||
start_index = i * len(xs) * len(ys)
|
start_index = i * len(xs) * len(ys)
|
||||||
end_index = start_index + len(xs) * len(ys)
|
end_index = start_index + len(xs) * len(ys)
|
||||||
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
|
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
|
||||||
if draw_legend:
|
if draw_legend:
|
||||||
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
|
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
|
||||||
sub_grids[i] = grid
|
|
||||||
|
grids[i] = grid
|
||||||
if include_sub_grids and len(zs) > 1:
|
if include_sub_grids and len(zs) > 1:
|
||||||
processed_result.images.insert(i+1, grid)
|
processed_result.images.insert(i+1, grid)
|
||||||
|
|
||||||
sub_grid_size = sub_grids[0].size
|
original_grid_size = grids[0].size
|
||||||
z_grid = images.image_grid(sub_grids, rows=1)
|
grids = images.image_grid(grids, rows=1)
|
||||||
if draw_legend:
|
processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||||
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
|
||||||
processed_result.images[0] = z_grid
|
|
||||||
|
|
||||||
return processed_result, sub_grids
|
return processed_result
|
||||||
|
|
||||||
|
|
||||||
class SharedSettingsStackHelper(object):
|
class SharedSettingsStackHelper(object):
|
||||||
|
@ -351,16 +350,10 @@ class Script(scripts.Script):
|
||||||
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row(variant="compact", elem_id="axis_options"):
|
with gr.Row(variant="compact", elem_id="axis_options"):
|
||||||
with gr.Column():
|
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
||||||
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
||||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
||||||
with gr.Column():
|
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
||||||
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
|
||||||
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
|
||||||
with gr.Column():
|
|
||||||
margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
|
||||||
|
|
||||||
with gr.Row(variant="compact", elem_id="swap_axes"):
|
|
||||||
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
||||||
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
|
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
|
||||||
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
|
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
|
||||||
|
@ -399,9 +392,9 @@ class Script(scripts.Script):
|
||||||
(z_values, "Z Values"),
|
(z_values, "Z Values"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
|
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds]
|
||||||
|
|
||||||
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
|
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds):
|
||||||
if not no_fixed_seeds:
|
if not no_fixed_seeds:
|
||||||
modules.processing.fix_seed(p)
|
modules.processing.fix_seed(p)
|
||||||
|
|
||||||
|
@ -583,7 +576,7 @@ class Script(scripts.Script):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
with SharedSettingsStackHelper():
|
with SharedSettingsStackHelper():
|
||||||
processed, sub_grids = draw_xyz_grid(
|
processed = draw_xyz_grid(
|
||||||
p,
|
p,
|
||||||
xs=xs,
|
xs=xs,
|
||||||
ys=ys,
|
ys=ys,
|
||||||
|
@ -596,14 +589,9 @@ class Script(scripts.Script):
|
||||||
include_lone_images=include_lone_images,
|
include_lone_images=include_lone_images,
|
||||||
include_sub_grids=include_sub_grids,
|
include_sub_grids=include_sub_grids,
|
||||||
first_axes_processed=first_axes_processed,
|
first_axes_processed=first_axes_processed,
|
||||||
second_axes_processed=second_axes_processed,
|
second_axes_processed=second_axes_processed
|
||||||
margin_size=margin_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if opts.grid_save and len(sub_grids) > 1:
|
|
||||||
for sub_grid in sub_grids:
|
|
||||||
images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||||
|
|
||||||
|
|
|
@ -807,13 +807,7 @@ footer {
|
||||||
margin: 0.3em;
|
margin: 0.3em;
|
||||||
}
|
}
|
||||||
|
|
||||||
.extra-network-subdirs{
|
|
||||||
padding: 0.2em 0.35em;
|
|
||||||
}
|
|
||||||
|
|
||||||
.extra-network-subdirs button{
|
|
||||||
margin: 0 0.15em;
|
|
||||||
}
|
|
||||||
|
|
||||||
#txt2img_extra_networks .search, #img2img_extra_networks .search{
|
#txt2img_extra_networks .search, #img2img_extra_networks .search{
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
|
|
|
@ -3,6 +3,6 @@
|
||||||
set PYTHON=
|
set PYTHON=
|
||||||
set GIT=
|
set GIT=
|
||||||
set VENV_DIR=
|
set VENV_DIR=
|
||||||
set COMMANDLINE_ARGS=--skip-torch-cuda-test --precision full --no-half
|
set COMMANDLINE_ARGS=
|
||||||
|
|
||||||
call webui.bat
|
call webui.bat
|
||||||
|
|
16
webui.py
16
webui.py
|
@ -12,7 +12,7 @@ from packaging import version
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
|
from modules import import_hook, errors, extra_networks
|
||||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||||
|
|
||||||
|
@ -52,9 +52,6 @@ else:
|
||||||
|
|
||||||
|
|
||||||
def check_versions():
|
def check_versions():
|
||||||
if shared.cmd_opts.skip_version_check:
|
|
||||||
return
|
|
||||||
|
|
||||||
expected_torch_version = "1.13.1"
|
expected_torch_version = "1.13.1"
|
||||||
|
|
||||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||||
|
@ -62,10 +59,7 @@ def check_versions():
|
||||||
You are running torch {torch.__version__}.
|
You are running torch {torch.__version__}.
|
||||||
The program is tested to work with torch {expected_torch_version}.
|
The program is tested to work with torch {expected_torch_version}.
|
||||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||||
Beware that this will cause a lot of large files to be downloaded, as well as
|
Beware that this will cause a lot of large files to be downloaded.
|
||||||
there are reports of issues with training tab on the latest version.
|
|
||||||
|
|
||||||
Use --skip-version-check commandline argument to disable this check.
|
|
||||||
""".strip())
|
""".strip())
|
||||||
|
|
||||||
expected_xformers_version = "0.0.16rc425"
|
expected_xformers_version = "0.0.16rc425"
|
||||||
|
@ -77,8 +71,6 @@ Use --skip-version-check commandline argument to disable this check.
|
||||||
You are running xformers {xformers.__version__}.
|
You are running xformers {xformers.__version__}.
|
||||||
The program is tested to work with xformers {expected_xformers_version}.
|
The program is tested to work with xformers {expected_xformers_version}.
|
||||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||||
|
|
||||||
Use --skip-version-check commandline argument to disable this check.
|
|
||||||
""".strip())
|
""".strip())
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,7 +119,6 @@ def initialize():
|
||||||
ui_extra_networks.intialize()
|
ui_extra_networks.intialize()
|
||||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
|
||||||
|
|
||||||
extra_networks.initialize()
|
extra_networks.initialize()
|
||||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
|
@ -236,8 +227,6 @@ def webui():
|
||||||
if launch_api:
|
if launch_api:
|
||||||
create_api(app)
|
create_api(app)
|
||||||
|
|
||||||
ui_extra_networks.add_pages_to_demo(app)
|
|
||||||
|
|
||||||
modules.script_callbacks.app_started_callback(shared.demo, app)
|
modules.script_callbacks.app_started_callback(shared.demo, app)
|
||||||
|
|
||||||
wait_on_server(shared.demo)
|
wait_on_server(shared.demo)
|
||||||
|
@ -265,7 +254,6 @@ def webui():
|
||||||
ui_extra_networks.intialize()
|
ui_extra_networks.intialize()
|
||||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
|
||||||
|
|
||||||
extra_networks.initialize()
|
extra_networks.initialize()
|
||||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user