Merge remote-tracking branch 'upstream/master' into interrogate_include_ranks_in_output

Greg Fuller 2022-10-12 12:44:41 -07:00
commit fb3cefb348
20 changed files with 707 additions and 226 deletions

View File

@@ -1045,7 +1045,6 @@ Bakemono Zukushi,0.67051035,anime
 Lucy Madox Brown,0.67032814,fineart
 Paul Wonner,0.6700563,scribbles
 Guido Borelli Da Caluso,0.66966087,digipa-high-impact
-Guido Borelli da Caluso,0.66966087,digipa-high-impact
 Emil Alzamora,0.5844039,nudity
 Heinrich Brocksieper,0.64469147,fineart
 Dan Smith,0.669563,digipa-high-impact


View File

@@ -3,9 +3,9 @@ channels:
 - pytorch
 - defaults
 dependencies:
-- python=3.8.5
-- pip=20.3
+- python=3.10
+- pip=22.2.2
 - cudatoolkit=11.3
-- pytorch=1.11.0
-- torchvision=0.12.0
-- numpy=1.19.2
+- pytorch=1.12.1
+- torchvision=0.13.1
+- numpy=1.23.1

View File

@@ -25,6 +25,7 @@ addEventListener('keydown', (event) => {
 		} else {
 			end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
 			weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
+			if (isNaN(weight)) return;
 			if (event.key == minus) weight -= 0.1;
 			if (event.key == plus) weight += 0.1;

View File

@@ -80,7 +80,10 @@ titles = {
     "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
     "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
-    "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.",
+    "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
+    "Filename word regex": "This regular expression will be used to extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
+    "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
 }

View File

@@ -101,7 +101,8 @@ function create_tab_index_args(tabId, args){
 }
 
 function get_extras_tab_index(){
-    return create_tab_index_args('mode_extras', arguments)
+    const [,,...args] = [...arguments]
+    return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
 }
 
 function create_submit_args(args){

View File

@@ -1,20 +1,98 @@
 import os.path
 from concurrent.futures import ProcessPoolExecutor
-from multiprocessing import get_context
+import multiprocessing
+import time
+import re
+
+re_special = re.compile(r'([\\()])')
+
+
+def get_deepbooru_tags(pil_image):
+    """
+    This method is for running only one image at a time for simple use.  Used by the img2img interrogate.
+    """
+    from modules import shared  # prevents circular reference
+
+    try:
+        create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
+        return get_tags_from_process(pil_image)
+    finally:
+        release_process()
+
+
+def create_deepbooru_opts():
+    from modules import shared
+
+    return {
+        "use_spaces": shared.opts.deepbooru_use_spaces,
+        "use_escape": shared.opts.deepbooru_escape,
+        "alpha_sort": shared.opts.deepbooru_sort_alpha,
+    }
 
 
-def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
+def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
+    model, tags = get_deepbooru_tags_model()
+    while True:  # while process is running, keep monitoring queue for new image
+        pil_image = queue.get()
+        if pil_image == "QUIT":
+            break
+        else:
+            deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
+
+
+def create_deepbooru_process(threshold, deepbooru_opts):
+    """
+    Creates deepbooru process.  A queue is created to send images into the process.  This enables multiple images
+    to be processed in a row without reloading the model or creating a new process.  To return the data, a shared
+    dictionary is created to hold the tags created.  To wait for tags to be returned, a value of -1 is assigned
+    to the dictionary and the method adding the image to the queue should wait for this value to be updated with
+    the tags.
+    """
+    from modules import shared  # prevents circular reference
+    shared.deepbooru_process_manager = multiprocessing.Manager()
+    shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
+    shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
+    shared.deepbooru_process_return["value"] = -1
+    shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
+    shared.deepbooru_process.start()
+
+
+def get_tags_from_process(image):
+    from modules import shared
+
+    shared.deepbooru_process_return["value"] = -1
+    shared.deepbooru_process_queue.put(image)
+    while shared.deepbooru_process_return["value"] == -1:
+        time.sleep(0.2)
+    caption = shared.deepbooru_process_return["value"]
+    shared.deepbooru_process_return["value"] = -1
+
+    return caption
+
+
+def release_process():
+    """
+    Stops the deepbooru process to return used memory
+    """
+    from modules import shared  # prevents circular reference
+    shared.deepbooru_process_queue.put("QUIT")
+    shared.deepbooru_process.join()
+    shared.deepbooru_process_queue = None
+    shared.deepbooru_process = None
+    shared.deepbooru_process_return = None
+    shared.deepbooru_process_manager = None
+
+
+def get_deepbooru_tags_model():
     import deepdanbooru as dd
     import tensorflow as tf
     import numpy as np
 
     this_folder = os.path.dirname(__file__)
     model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
     if not os.path.exists(os.path.join(model_path, 'project.json')):
         # there is no point importing these every time
         import zipfile
         from basicsr.utils.download_util import load_file_from_url
-        load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
-                           model_path)
+        load_file_from_url(
+            r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
+            model_path)
         with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
             zip_ref.extractall(model_path)
@@ -24,6 +102,17 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
     model = dd.project.load_model_from_project(
         model_path, compile_model=True
     )
+    return model, tags
+
+
+def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
+    import deepdanbooru as dd
+    import tensorflow as tf
+    import numpy as np
+
+    alpha_sort = deepbooru_opts['alpha_sort']
+    use_spaces = deepbooru_opts['use_spaces']
+    use_escape = deepbooru_opts['use_escape']
 
     width = model.input_shape[2]
     height = model.input_shape[1]
@@ -46,32 +135,35 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
     for i, tag in enumerate(tags):
         result_dict[tag] = y[i]
 
-    result_tags_out = []
+    unsorted_tags_in_theshold = []
     result_tags_print = []
     for tag in tags:
         if result_dict[tag] >= threshold:
             if tag.startswith("rating:"):
                 continue
-            tag_formatted = tag.replace('_', ' ').replace(':', ' ')
-            if include_ranks:
-                result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})')
-            else:
-                result_tags_out.append(tag_formatted)
+            unsorted_tags_in_theshold.append((result_dict[tag], tag))
             result_tags_print.append(f'{result_dict[tag]} {tag}')
 
+    # sort tags
+    result_tags_out = []
+    sort_ndx = 0
+    if alpha_sort:
+        sort_ndx = 1
+
+    # sort descending by likelihood, or ascending alphabetically when alpha_sort is set
+    unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
+    for weight, tag in unsorted_tags_in_theshold:
+        result_tags_out.append(tag)
+
     print('\n'.join(sorted(result_tags_print, reverse=True)))
-    return ', '.join(result_tags_out)
-
-
-def subprocess_init_no_cuda():
-    import os
-    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
-
-
-def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False):
-    context = get_context('spawn')
-    with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
-        f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks)
-        ret = f.result()  # will rethrow any exceptions
-    return ret
+    tags_text = ', '.join(result_tags_out)
+
+    if use_spaces:
+        tags_text = tags_text.replace('_', ' ')
+
+    if use_escape:
+        tags_text = re.sub(re_special, r'\\\1', tags_text)
+
+    return tags_text.replace(':', ' ')
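The docstrings above imply a start-once / tag-many / release lifecycle; a minimal sketch of batch use (file names are hypothetical):

import modules.deepbooru as deepbooru
from PIL import Image

# start the persistent worker once; the model loads inside the child process
deepbooru.create_deepbooru_process(0.5, deepbooru.create_deepbooru_opts())
try:
    for path in ("a.png", "b.png"):  # hypothetical input files
        tags = deepbooru.get_tags_from_process(Image.open(path))
        print(path, "->", tags)
finally:
    deepbooru.release_process()  # stops the process and frees the model's memory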

View File

@@ -1,3 +1,4 @@
+import math
 import os
 
 import numpy as np
@@ -19,7 +20,7 @@ import gradio as gr
 cached_images = {}
 
 
-def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
     devices.torch_gc()
 
     imageArr = []
@@ -67,8 +68,13 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
             info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
             image = res
 
+        if resize_mode == 1:
+            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
+            crop_info = " (crop)" if upscaling_crop else ""
+            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+
         if upscaling_resize != 1.0:
-            def upscale(image, scaler_index, resize):
+            def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
                 small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
                 pixels = tuple(np.array(small).flatten().tolist())
                 key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
@@ -77,15 +83,19 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
                 if c is None:
                     upscaler = shared.sd_upscalers[scaler_index]
                     c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+                    if mode == 1 and crop:
+                        cropped = Image.new("RGB", (resize_w, resize_h))
+                        cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
+                        c = cropped
                     cached_images[key] = c
 
                 return c
 
             info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
-            res = upscale(image, extras_upscaler_1, upscaling_resize)
+            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
 
             if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
-                res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
                 info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
                 res = Image.blend(res, res2, extras_upscaler_2_visibility)
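Worked example for the new "Scale to" branch (illustrative numbers): for a 512x768 source with a 1024x1024 target, upscaling_resize becomes max(1024/512, 1024/768) = 2.0, so the upscaler outputs 1024x1536; with "Crop to fit" enabled, that result is pasted centered onto a 1024x1024 canvas, trimming 256 px from the top and bottom.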

View File

@@ -14,7 +14,7 @@ import torch
 from torch import einsum
 from einops import rearrange, repeat
 import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -120,6 +120,17 @@ def load_hypernetwork(filename):
         shared.loaded_hypernetwork = None
 
 
+def find_closest_hypernetwork_name(search: str):
+    if not search:
+        return None
+    search = search.lower()
+    applicable = [name for name in shared.hypernetworks if search in name.lower()]
+    if not applicable:
+        return None
+    applicable = sorted(applicable, key=lambda name: len(name))
+    return applicable[0]
+
+
 def apply_hypernetwork(hypernetwork, context, layer=None):
     hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
 
@@ -164,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
 
 
 def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
-    assert hypernetwork_name, 'embedding not selected'
+    assert hypernetwork_name, 'hypernetwork not selected'
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
@@ -212,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     if ititial_step > steps:
         return hypernetwork, filename
 
-    schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
-    (learn_rate, end_step) = next(schedules)
-    print(f'Training at rate of {learn_rate} until step {end_step}')
-
-    optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, (x, text, cond) in pbar:
+    for i, entry in pbar:
         hypernetwork.step = i + ititial_step
 
-        if hypernetwork.step > end_step:
-            try:
-                (learn_rate, end_step) = next(schedules)
-            except Exception:
-                break
-            tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
-            for pg in optimizer.param_groups:
-                pg['lr'] = learn_rate
+        scheduler.apply(optimizer, hypernetwork.step)
+        if scheduler.finished:
+            break
 
         if shared.state.interrupted:
             break
 
         with torch.autocast("cuda"):
-            cond = cond.to(devices.device)
-            x = x.to(devices.device)
+            cond = entry.cond.to(devices.device)
+            x = entry.latent.to(devices.device)
             loss = shared.sd_model(x.unsqueeze(0), cond)[0]
             del x
             del cond
@@ -256,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
         if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
             last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
 
-            preview_text = text if preview_image_prompt == "" else preview_image_prompt
+            preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
 
             optimizer.zero_grad()
             shared.sd_model.cond_stage_model.to(devices.device)
@@ -271,15 +274,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             )
 
             processed = processing.process_images(p)
-            image = processed.images[0]
+            image = processed.images[0] if len(processed.images)>0 else None
 
             if unload:
                 shared.sd_model.cond_stage_model.to(devices.cpu)
                 shared.sd_model.first_stage_model.to(devices.cpu)
 
-            shared.state.current_image = image
-            image.save(last_saved_image)
-
-            last_saved_image += f", prompt: {preview_text}"
+            if image is not None:
+                shared.state.current_image = image
+                image.save(last_saved_image)
+                last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = hypernetwork.step
 
@@ -288,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
 <p>
 Loss: {losses.mean():.7f}<br/>
 Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
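The new find_closest_hypernetwork_name matches case-insensitively by substring and prefers the shortest candidate; a runnable sketch mirroring its logic with hypothetical names:

names = ["anime", "anime_3", "furry"]  # hypothetical hypernetwork names

def closest(search, names=names):  # same logic as find_closest_hypernetwork_name
    if not search:
        return None
    search = search.lower()
    applicable = sorted([n for n in names if search in n.lower()], key=len)
    return applicable[0] if applicable else None

print(closest("ANI"))    # anime  (case-insensitive; shortest containing name wins)
print(closest("tiger"))  # None   (no name contains the search string)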

View File

@@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     fixes.append(fix[1])
             self.hijack.fixes.append(fixes)
 
-            z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
+            tokens = []
+            multipliers = []
+            for j in range(len(remade_batch_tokens)):
+                if len(remade_batch_tokens[j]) > 0:
+                    tokens.append(remade_batch_tokens[j][:75])
+                    multipliers.append(batch_multipliers[j][:75])
+                else:
+                    tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
+                    multipliers.append([1.0] * 75)
+
+            z1 = self.process_tokens(tokens, multipliers)
             z = z1 if z is None else torch.cat((z, z1), axis=-2)
 
         remade_batch_tokens = rem_tokens
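The new branch exists because process_tokens needs rectangular input: a batch row that has run out of tokens is replaced by 75 copies of the end-of-text token with neutral 1.0 multipliers. A minimal sketch; the id 49407 is assumed here to be what self.wrapped.tokenizer.eos_token_id returns for CLIP:

eos = 49407  # assumption: CLIP's end-of-text token id

remade_batch_tokens = [[320, 1125, 539], []]  # second prompt ran out of tokens in this chunk
tokens = [row[:75] if row else [eos] * 75 for row in remade_batch_tokens]
# the empty row becomes 75 eos ids (paired with 1.0 multipliers), so the batch stays rectangular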

View File

@@ -86,6 +86,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
 xformers_available = False
 config_filename = cmd_opts.ui_settings_file
 
+os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
 hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
 loaded_hypernetwork = None
 
@@ -229,7 +230,10 @@ options_templates.update(options_section(('system', "System"), {
 }))
 
 options_templates.update(options_section(('training', "Training"), {
-    "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"),
+    "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+    "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
 }))
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
@@ -255,8 +259,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
     "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
     "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
     "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+    "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
     "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+    "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
+    "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
+    "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
 }))
 
 options_templates.update(options_section(('ui', "User interface"), {

View File

@@ -11,11 +11,21 @@ import tqdm
 from modules import devices, shared
 import re
 
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+    def __init__(self, filename=None, latent=None, filename_text=None):
+        self.filename = filename
+        self.latent = latent
+        self.filename_text = filename_text
+        self.cond = None
+        self.cond_text = None
 
 
 class PersonalizedBase(Dataset):
     def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
+        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
 
         self.placeholder_token = placeholder_token
 
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
             except Exception:
                 continue
 
+            text_filename = os.path.splitext(path)[0] + ".txt"
             filename = os.path.basename(path)
-            filename_tokens = os.path.splitext(filename)[0]
-            filename_tokens = re_tag.findall(filename_tokens)
+
+            if os.path.exists(text_filename):
+                with open(text_filename, "r", encoding="utf8") as file:
+                    filename_text = file.read()
+            else:
+                filename_text = os.path.splitext(filename)[0]
+                filename_text = re.sub(re_numbers_at_start, '', filename_text)
+                if re_word:
+                    tokens = re_word.findall(filename_text)
+                    filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
 
             npimage = np.array(image).astype(np.uint8)
             npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
             init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
             init_latent = init_latent.to(devices.cpu)
 
-            if include_cond:
-                text = self.create_text(filename_tokens)
-                cond = cond_model([text]).to(devices.cpu)
-            else:
-                cond = None
+            entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
 
-            self.dataset.append((init_latent, filename_tokens, cond))
+            if include_cond:
+                entry.cond_text = self.create_text(filename_text)
+                entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
+
+            self.dataset.append(entry)
 
         self.length = len(self.dataset) * repeats
 
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
     def shuffle(self):
         self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
 
-    def create_text(self, filename_tokens):
+    def create_text(self, filename_text):
         text = random.choice(self.lines)
         text = text.replace("[name]", self.placeholder_token)
-        text = text.replace("[filewords]", ' '.join(filename_tokens))
+        text = text.replace("[filewords]", filename_text)
         return text
 
     def __len__(self):
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
             self.shuffle()
 
         index = self.indexes[i % len(self.indexes)]
-        x, filename_tokens, cond = self.dataset[index]
-        text = self.create_text(filename_tokens)
-        return x, text, cond
+        entry = self.dataset[index]
+
+        if entry.cond is None:
+            entry.cond_text = self.create_text(entry.filename_text)
+
+        return entry

View File

@@ -0,0 +1,219 @@
import base64
import json
import numpy as np
import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch


class EmbeddingEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
        return json.JSONEncoder.default(self, obj)


class EmbeddingDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, d):
        if 'TORCHTENSOR' in d:
            return torch.from_numpy(np.array(d['TORCHTENSOR']))
        return d


def embedding_to_b64(data):
    d = json.dumps(data, cls=EmbeddingEncoder)
    return base64.b64encode(d.encode())


def embedding_from_b64(data):
    d = base64.b64decode(data)
    return json.loads(d, cls=EmbeddingDecoder)


def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    while True:
        seed = (a * seed + c) % m
        yield seed % 255


def xor_block(block):
    g = lcg()
    randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)


def style_block(block, sequence):
    im = Image.new('RGB', (block.shape[1], block.shape[0]))
    draw = ImageDraw.Draw(im)
    i = 0
    for x in range(-6, im.size[0], 8):
        for yi, y in enumerate(range(-6, im.size[1], 8)):
            offset = 0
            if yi % 2 == 0:
                offset = 4
            shade = sequence[i % len(sequence)]
            i += 1
            draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))

    fg = np.array(im).astype(np.uint8) & 0xF0
    return block ^ fg


def insert_image_data_embed(image, data):
    d = 3
    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
    data_np_high = data_np_ >> 4
    data_np_low = data_np_ & 0x0F

    h = image.size[1]
    next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
    next_size = next_size + ((h*d)-(next_size % (h*d)))

    data_np_low.resize(next_size)
    data_np_low = data_np_low.reshape((h, -1, d))

    data_np_high.resize(next_size)
    data_np_high = data_np_high.reshape((h, -1, d))

    edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
    edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)

    data_np_low = style_block(data_np_low, sequence=edge_style)
    data_np_low = xor_block(data_np_low)
    data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
    data_np_high = xor_block(data_np_high)

    im_low = Image.fromarray(data_np_low, mode='RGB')
    im_high = Image.fromarray(data_np_high, mode='RGB')

    background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
    background.paste(im_low, (0, 0))
    background.paste(image, (im_low.size[0]+1, 0))
    background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))

    return background


def crop_black(img, tol=0):
    mask = (img > tol).all(2)
    mask0, mask1 = mask.any(0), mask.any(1)
    col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
    row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
    return img[row_start:row_end, col_start:col_end]


def extract_image_data_embed(image):
    d = 3
    outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
    black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
    if black_cols[0].shape[0] < 2:
        print('No Image data blocks found.')
        return None

    data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
    data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)

    data_block_lower = xor_block(data_block_lower)
    data_block_upper = xor_block(data_block_upper)

    data_block = (data_block_upper << 4) | (data_block_lower)
    data_block = data_block.flatten().tobytes()

    data = zlib.decompress(data_block)
    return json.loads(data, cls=EmbeddingDecoder)


def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
    from math import cos

    image = srcimage.copy()

    if textfont is None:
        try:
            textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
            textfont = opts.font or Roboto
        except Exception:
            textfont = Roboto

    factor = 1.5
    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
    for y in range(image.size[1]):
        mag = 1-cos(y/image.size[1]*factor)
        mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
        gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))

    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))

    draw = ImageDraw.Draw(image)
    fontsize = 32
    font = ImageFont.truetype(textfont, fontsize)
    padding = 10

    _, _, w, h = draw.textbbox((0, 0), title, font=font)
    fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
    font = ImageFont.truetype(textfont, fontsize)
    _, _, w, h = draw.textbbox((0, 0), title, font=font)
    draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))

    _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
    fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
    _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
    fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
    _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
    fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)

    font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))

    draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))

    return image


if __name__ == '__main__':

    testEmbed = Image.open('test_embedding.png')

    data = extract_image_data_embed(testEmbed)
    assert data is not None

    data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
    assert data is not None

    image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
    cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')

    test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}

    embedded_image = insert_image_data_embed(cap_image, test_embed)

    retrived_embed = extract_image_data_embed(embedded_image)

    assert str(retrived_embed) == str(test_embed)

    embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
    assert embedded_image == embedded_image2

    g = lcg()
    shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()

    reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
                        95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
                        160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
                        38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
                        30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
                        41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
                        66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
                        204, 86, 73, 222, 44, 198, 118, 240, 97]

    assert shared_random == reference_random

    hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())

    assert 12731374 == hunna_kay_random_sum
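A minimal round trip through the base64 helpers above, mirroring the module's own self-test (tensor shape is arbitrary):

import numpy as np
import torch

from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64

data = {'name': 'my-embedding', 'string_to_param': {'*': torch.from_numpy(np.random.random((2, 768)))}}
b64 = embedding_to_b64(data)        # tensors serialize as {'TORCHTENSOR': ...} lists, then base64
restored = embedding_from_b64(b64)  # object_hook rebuilds tensors via torch.from_numpy
assert str(restored) == str(data)   # same comparison the __main__ self-test uses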

View File

@@ -1,6 +1,12 @@
+import tqdm
+
+
-class LearnSchedule:
+class LearnScheduleIterator:
     def __init__(self, learn_rate, max_steps, cur_step=0):
+        """
+        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
+        """
+
         pairs = learn_rate.split(',')
         self.rates = []
         self.it = 0
@@ -32,3 +38,32 @@ class LearnScheduleIterator:
             return self.rates[self.it - 1]
         else:
             raise StopIteration
+
+
+class LearnRateScheduler:
+    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+        (self.learn_rate, self.end_step) = next(self.schedules)
+        self.verbose = verbose
+
+        if self.verbose:
+            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+        self.finished = False
+
+    def apply(self, optimizer, step_number):
+        if step_number <= self.end_step:
+            return
+
+        try:
+            (self.learn_rate, self.end_step) = next(self.schedules)
+        except Exception:
+            self.finished = True
+            return
+
+        if self.verbose:
+            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+        for pg in optimizer.param_groups:
+            pg['lr'] = self.learn_rate
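Both training loops in this merge consume the class the same way; a runnable sketch (the schedule string and parameters are illustrative):

import torch

from modules.textual_inversion.learn_schedule import LearnRateScheduler

weights = [torch.nn.Parameter(torch.zeros(4))]  # stand-in for the trainable tensors
scheduler = LearnRateScheduler("0.001:100, 1e-5:1000", max_steps=1000)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

for step in range(1000):
    scheduler.apply(optimizer, step)  # moves to the next rate once step passes end_step
    if scheduler.finished:
        break
    # ... one training step at optimizer.param_groups[0]['lr'] ...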

View File

@@ -3,11 +3,35 @@ from PIL import Image, ImageOps
 import platform
 import sys
 import tqdm
+import time
 
 from modules import shared, images
+from modules.shared import opts, cmd_opts
+if cmd_opts.deepdanbooru:
+    import modules.deepbooru as deepbooru
 
 
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+    try:
+        if process_caption:
+            shared.interrogator.load()
+
+        if process_caption_deepbooru:
+            deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
+
+        preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+
+    finally:
+        if process_caption:
+            shared.interrogator.send_blip_to_ram()
+
+        if process_caption_deepbooru:
+            deepbooru.release_process()
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     width = process_width
     height = process_height
     src = os.path.abspath(process_src)
@@ -22,19 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
     shared.state.textinfo = "Preprocessing..."
     shared.state.job_count = len(files)
 
-    if process_caption:
-        shared.interrogator.load()
-
     def save_pic_with_caption(image, index):
-        if process_caption:
-            caption = "-" + shared.interrogator.generate_caption(image)
-            caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
-        else:
-            caption = filename
-            caption = os.path.splitext(caption)[0]
-            caption = os.path.basename(caption)
+        caption = ""
 
-        image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
+        if process_caption:
+            caption += shared.interrogator.generate_caption(image)
+
+        if process_caption_deepbooru:
+            if len(caption) > 0:
+                caption += ", "
+            caption += deepbooru.get_tags_from_process(image)
+
+        filename_part = filename
+        filename_part = os.path.splitext(filename_part)[0]
+        filename_part = os.path.basename(filename_part)
+
+        basename = f"{index:05}-{subindex[0]}-{filename_part}"
+        image.save(os.path.join(dst, f"{basename}.png"))
+
+        if len(caption) > 0:
+            with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+                file.write(caption)
+
         subindex[0] += 1
 
     def save_pic(image, index):
@@ -79,30 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
                 save_pic(img, index)
 
         shared.state.nextjob()
-
-    if process_caption:
-        shared.interrogator.send_blip_to_ram()
-
-
-def sanitize_caption(base_path, original_caption, suffix):
-    operating_system = platform.system().lower()
-    if (operating_system == "windows"):
-        invalid_path_characters = "\\/:*?\"<>|"
-        max_path_length = 259
-    else:
-        invalid_path_characters = "/"  # linux/macos
-        max_path_length = 1023
-    caption = original_caption
-    for invalid_character in invalid_path_characters:
-        caption = caption.replace(invalid_character, "")
-    fixed_path_length = len(base_path) + len(suffix)
-    if fixed_path_length + len(caption) <= max_path_length:
-        return caption
-    caption_tokens = caption.split()
-    new_caption = ""
-    for token in caption_tokens:
-        last_caption = new_caption
-        new_caption = new_caption + token + " "
-        if (len(new_caption) + fixed_path_length - 1 > max_path_length):
-            break
-    print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
-    return last_caption.strip()
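With the new naming scheme (illustrative values): a source file cat_photo.jpg processed at index 12, subindex 0 is written as 00012-0-cat_photo.png, and the BLIP and/or deepbooru caption goes into a sibling 00012-0-cat_photo.txt; when both captioners are disabled, no .txt is written. Since captions no longer become part of the filename, the old sanitize_caption path-length workaround could be dropped.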

Binary file not shown (added: image, 478 KiB).

View File

@@ -7,11 +7,15 @@ import tqdm
 import html
 import datetime
 
+from PIL import Image, PngImagePlugin
+
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+
+from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
+                                                       insert_image_data_embed, extract_image_data_embed,
+                                                       caption_image_overlay)
 
 
 class Embedding:
     def __init__(self, vec, name, step=None):
@@ -81,6 +85,17 @@ class EmbeddingDatabase:
         def process_file(path, filename):
             name = os.path.splitext(filename)[0]
 
-            data = torch.load(path, map_location="cpu")
+            data = []
+
+            if filename.upper().endswith('.PNG'):
+                embed_image = Image.open(path)
+                if 'sd-ti-embedding' in embed_image.text:
+                    data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+                    name = data.get('name', name)
+                else:
+                    data = extract_image_data_embed(embed_image)
+                    name = data.get('name', name)
+            else:
+                data = torch.load(path, map_location="cpu")
 
             # textual inversion embeddings
@@ -157,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -179,11 +194,17 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     else:
         images_dir = None
 
+    if create_image_every > 0 and save_image_with_stored_embedding:
+        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+        os.makedirs(images_embeds_dir, exist_ok=True)
+    else:
+        images_embeds_dir = None
+
     cond_model = shared.sd_model.cond_stage_model
 
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
 
     hijack = sd_hijack.model_hijack
 
@@ -199,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     if ititial_step > steps:
         return embedding, filename
 
-    schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
-    (learn_rate, end_step) = next(schedules)
-    print(f'Training at rate of {learn_rate} until step {end_step}')
-
-    optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, (x, text, _) in pbar:
+    for i, entry in pbar:
         embedding.step = i + ititial_step
 
-        if embedding.step > end_step:
-            try:
-                (learn_rate, end_step) = next(schedules)
-            except:
-                break
-            tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
-            for pg in optimizer.param_groups:
-                pg['lr'] = learn_rate
+        scheduler.apply(optimizer, embedding.step)
+        if scheduler.finished:
+            break
 
         if shared.state.interrupted:
            break
 
         with torch.autocast("cuda"):
-            c = cond_model([text])
-            x = x.to(devices.device)
+            c = cond_model([entry.cond_text])
+            x = entry.latent.to(devices.device)
             loss = shared.sd_model(x.unsqueeze(0), c)[0]
             del x
 
@@ -246,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
             last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
 
-            preview_text = text if preview_image_prompt == "" else preview_image_prompt
+            preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
 
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
@@ -262,6 +275,26 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             image = processed.images[0]
 
             shared.state.current_image = image
 
+            if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+
+                last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
+                info = PngImagePlugin.PngInfo()
+                data = torch.load(last_saved_file)
+                info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+                title = "<{}>".format(data.get('name', '???'))
+                checkpoint = sd_models.select_checkpoint()
+                footer_left = checkpoint.model_name
+                footer_mid = '[{}]'.format(checkpoint.hash)
+                footer_right = '{}'.format(embedding.step)
+
+                captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+                captioned_image = insert_image_data_embed(captioned_image, data)
+
+                captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
             image.save(last_saved_image)
 
             last_saved_image += f", prompt: {preview_text}"
@@ -272,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
 <p>
 Loss: {losses.mean():.7f}<br/>
 Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
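Given the new process_file branch, an embedding distributed as a PNG can be read back with the image_embedding helpers; a sketch (the path is hypothetical):

from PIL import Image

from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed

embed_image = Image.open("embeddings/my-embed.png")  # hypothetical path
if 'sd-ti-embedding' in embed_image.text:
    data = embedding_from_b64(embed_image.text['sd-ti-embedding'])  # fast path: PNG text chunk
else:
    data = extract_image_data_embed(embed_image)  # fallback: data encoded in the side blocks' pixels
print(data.get('name', 'unnamed'))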

View File

@ -131,6 +131,8 @@ def save_files(js_data, images, do_make_zip, index):
images = [images[index]] images = [images[index]]
start_index = index start_index = index
os.makedirs(opts.outdir_save, exist_ok=True)
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0 at_start = file.tell() == 0
writer = csv.writer(file) writer = csv.writer(file)
@ -181,8 +183,15 @@ def wrap_gradio_call(func, extra_outputs=None):
try: try:
res = list(func(*args, **kwargs)) res = list(func(*args, **kwargs))
except Exception as e: except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr) print("Error completing request", file=sys.stderr)
print("Arguments:", args, kwargs, file=sys.stderr) argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
shared.state.job = "" shared.state.job = ""
@ -311,12 +320,13 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
def interrogate(image): def interrogate(image):
prompt = shared.interrogator.interrogate(image, include_ranks=opts.interrogate_return_ranks) prompt = shared.interrogator.interrogate(image)
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
def interrogate_deepbooru(image): def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold, opts.interrogate_return_ranks) prompt = get_deepbooru_tags(image)
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
@ -911,7 +921,15 @@ def create_ui(wrap_gradio_gpu_call):
with gr.TabItem('Batch Process'): with gr.TabItem('Batch Process'):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
with gr.TabItem('Scale to'):
with gr.Group():
with gr.Row():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
with gr.Group(): with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
@ -942,6 +960,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=wrap_gradio_gpu_call(modules.extras.run_extras), fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index", _js="get_extras_tab_index",
inputs=[ inputs=[
dummy_component,
dummy_component, dummy_component,
extras_image, extras_image,
image_batch, image_batch,
@ -949,6 +968,9 @@ def create_ui(wrap_gradio_gpu_call):
codeformer_visibility, codeformer_visibility,
codeformer_weight, codeformer_weight,
upscaling_resize, upscaling_resize,
upscaling_resize_w,
upscaling_resize_h,
upscaling_crop,
extras_upscaler_1, extras_upscaler_1,
extras_upscaler_2, extras_upscaler_2,
extras_upscaler_2_visibility, extras_upscaler_2_visibility,
@ -1013,14 +1035,14 @@ def create_ui(wrap_gradio_gpu_call):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks() as textual_inversion_interface: with gr.Blocks() as train_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>") with gr.Row().style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
new_embedding_name = gr.Textbox(label="Name") new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*") initialization_text = gr.Textbox(label="Initialization text", value="*")
 nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
@@ -1032,9 +1054,7 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.Column():
                     create_embedding = gr.Button(value="Create embedding", variant='primary')

-            with gr.Group():
-                gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
+            with gr.Tab(label="Create hypernetwork"):
                 new_hypernetwork_name = gr.Textbox(label="Name")
                 new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
@@ -1045,9 +1065,7 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.Column():
                     create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')

-            with gr.Group():
-                gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
+            with gr.Tab(label="Preprocess images"):
                 process_src = gr.Textbox(label='Source directory')
                 process_dst = gr.Textbox(label='Destination directory')
                 process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
@@ -1056,7 +1074,8 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.Row():
                     process_flip = gr.Checkbox(label='Create flipped copies')
                     process_split = gr.Checkbox(label='Split oversized images into two')
-                    process_caption = gr.Checkbox(label='Use BLIP caption as filename')
+                    process_caption = gr.Checkbox(label='Use BLIP for caption')
+                    process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)

                 with gr.Row():
                     with gr.Column(scale=3):
@@ -1065,7 +1084,7 @@ def create_ui(wrap_gradio_gpu_call):
                     with gr.Column():
                         run_preprocess = gr.Button(value="Preprocess", variant='primary')

-            with gr.Group():
+            with gr.Tab(label="Train"):
                 gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
                 train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                 train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
@@ -1076,9 +1095,9 @@ def create_ui(wrap_gradio_gpu_call):
                 training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                 training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                 steps = gr.Number(label='Max steps', value=100000, precision=0)
-                num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
                 create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
                 save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
+                save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
                 preview_image_prompt = gr.Textbox(label='Preview prompt', value="")

                 with gr.Row():
@@ -1134,6 +1153,7 @@ def create_ui(wrap_gradio_gpu_call):
                 process_flip,
                 process_split,
                 process_caption,
+                process_caption_deepbooru
             ],
             outputs=[
                 ti_output,
@@ -1152,10 +1172,10 @@ def create_ui(wrap_gradio_gpu_call):
                 training_width,
                 training_height,
                 steps,
-                num_repeats,
                 create_image_every,
                 save_embedding_every,
                 template_file,
+                save_image_with_stored_embedding,
                 preview_image_prompt,
             ],
             outputs=[
@@ -1361,7 +1381,7 @@ Requested path was: {f}
         (extras_interface, "Extras", "extras"),
         (pnginfo_interface, "PNG Info", "pnginfo"),
         (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
-        (textual_inversion_interface, "Textual inversion", "ti"),
+        (train_interface, "Train", "ti"),
         (settings_interface, "Settings", "settings"),
     ]
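Note: the UI change above swaps each gr.Group() plus gr.HTML heading for a gr.Tab, letting the tab bar carry the section label instead of raw markup. A minimal, self-contained sketch of the same pattern, assuming a gradio 3.x environment (the widget names here are illustrative, not taken from the diff):

    import gradio as gr

    # Each tool lives in its own tab; the tab label replaces the old HTML heading.
    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.Tab(label="Create embedding"):
                name = gr.Textbox(label="Name")
                create = gr.Button(value="Create embedding", variant='primary')
            with gr.Tab(label="Preprocess images"):
                src = gr.Textbox(label='Source directory')
                run = gr.Button(value="Preprocess", variant='primary')

    demo.launch()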

View File: scripts/img2imgalt.py

@@ -129,8 +129,6 @@ class Script(scripts.Script):
         return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]

     def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment):
-        p.batch_size = 1
-        p.batch_count = 1

         def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
@@ -154,7 +152,7 @@ class Script(scripts.Script):
             rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
             self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)

-            rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])])
+            rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

             combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
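Note: the combined_noise expression blends the recovered noise with freshly seeded noise and divides by sqrt(randomness**2 + (1-randomness)**2), which keeps the mix at roughly unit variance when both inputs are independent unit-variance tensors. A small standalone numpy sketch of that identity (not part of the script itself):

    import numpy as np

    rng = np.random.default_rng(0)
    rec_noise = rng.standard_normal(100_000)   # stands in for the recovered noise
    rand_noise = rng.standard_normal(100_000)  # stands in for the seeded random noise

    for randomness in (0.0, 0.3, 0.7, 1.0):
        mixed = ((1 - randomness) * rec_noise + randomness * rand_noise) \
                / ((randomness**2 + (1 - randomness)**2) ** 0.5)
        print(randomness, round(mixed.std(), 3))  # std stays ~1.0 at every mix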

View File: scripts/xy_grid.py

@@ -77,14 +77,42 @@ def apply_sampler(p, x, xs):
     p.sampler_index = sampler_index

+def confirm_samplers(p, xs):
+    samplers_dict = build_samplers_dict(p)
+    for x in xs:
+        if x.lower() not in samplers_dict.keys():
+            raise RuntimeError(f"Unknown sampler: {x}")

 def apply_checkpoint(p, x, xs):
     info = modules.sd_models.get_closet_checkpoint_match(x)
-    assert info is not None, f'Checkpoint for {x} not found'
+    if info is None:
+        raise RuntimeError(f"Unknown checkpoint: {x}")
     modules.sd_models.reload_model_weights(shared.sd_model, info)

+def confirm_checkpoints(p, xs):
+    for x in xs:
+        if modules.sd_models.get_closet_checkpoint_match(x) is None:
+            raise RuntimeError(f"Unknown checkpoint: {x}")

 def apply_hypernetwork(p, x, xs):
-    hypernetwork.load_hypernetwork(x)
+    if x.lower() in ["", "none"]:
+        name = None
+    else:
+        name = hypernetwork.find_closest_hypernetwork_name(x)
+        if not name:
+            raise RuntimeError(f"Unknown hypernetwork: {x}")
+    hypernetwork.load_hypernetwork(name)

+def confirm_hypernetworks(p, xs):
+    for x in xs:
+        if x.lower() in ["", "none"]:
+            continue
+        if not hypernetwork.find_closest_hypernetwork_name(x):
+            raise RuntimeError(f"Unknown hypernetwork: {x}")

 def apply_clip_skip(p, x, xs):
@@ -121,29 +149,29 @@ def str_permutations(x):
     return x

-AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
-AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
+AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
+AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])

 axis_options = [
-    AxisOption("Nothing", str, do_nothing, format_nothing),
-    AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
-    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
-    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
-    AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
-    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
-    AxisOption("Prompt S/R", str, apply_prompt, format_value),
-    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
-    AxisOption("Sampler", str, apply_sampler, format_value),
-    AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
-    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
-    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
-    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
-    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
-    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
-    AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
-    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
-    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),  # as it is now all AxisOptionImg2Img items must go after AxisOption ones
+    AxisOption("Nothing", str, do_nothing, format_nothing, None),
+    AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
+    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
+    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
+    AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
+    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
+    AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
+    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
+    AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
+    AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
+    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
+    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
+    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
+    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
+    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
+    AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
+    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
+    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),  # as it is now all AxisOptionImg2Img items must go after AxisOption ones
 ]
@@ -271,15 +299,8 @@ class Script(scripts.Script):
         valslist = [opt.type(x) for x in valslist]

         # Confirm options are valid before starting
-        if opt.label == "Sampler":
-            samplers_dict = build_samplers_dict(p)
-            for sampler_val in valslist:
-                if sampler_val.lower() not in samplers_dict.keys():
-                    raise RuntimeError(f"Unknown sampler: {sampler_val}")
-        elif opt.label == "Checkpoint name":
-            for ckpt_val in valslist:
-                if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
-                    raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
+        if opt.confirm:
+            opt.confirm(p, valslist)

         return valslist
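Note: the xy_grid refactor above moves validation out of label-based if/elif branches and into a confirm callback stored on each AxisOption, so every axis value is checked before the grid starts rendering. A stripped-down, runnable sketch of the dispatch shape (confirm_positive and the two options are invented for illustration):

    from collections import namedtuple

    AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])

    def confirm_positive(p, xs):
        for x in xs:
            if x <= 0:
                raise RuntimeError(f"Invalid steps value: {x}")

    # confirm is None for options that need no pre-flight check.
    axis_options = [
        AxisOption("Steps", int, lambda p, x, xs: setattr(p, "steps", x), str, confirm_positive),
        AxisOption("Seed", int, lambda p, x, xs: setattr(p, "seed", x), str, None),
    ]

    def process_axis(opt, vals, p=None):
        valslist = [opt.type(x) for x in vals]
        if opt.confirm:
            opt.confirm(p, valslist)  # fail fast, before any images are generated
        return valslist

    print(process_axis(axis_options[0], ["20", "30"]))  # [20, 30]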

View File: webui.py

@@ -31,12 +31,7 @@ from modules.paths import script_path
 from modules.shared import cmd_opts
 import modules.hypernetworks.hypernetwork

-modelloader.cleanup_models()
-modules.sd_models.setup_model()
-codeformer.setup_model(cmd_opts.codeformer_models_path)
-gfpgan.setup_model(cmd_opts.gfpgan_models_path)
-shared.face_restorers.append(modules.face_restoration.FaceRestoration())
-modelloader.load_upscalers()

 queue_lock = threading.Lock()
@@ -78,15 +73,24 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
     return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)

-modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
-
-shared.sd_model = modules.sd_models.load_model()
-shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
-shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+def initialize():
+    modelloader.cleanup_models()
+    modules.sd_models.setup_model()
+    codeformer.setup_model(cmd_opts.codeformer_models_path)
+    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+    shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+    modelloader.load_upscalers()
+
+    modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
+
+    shared.sd_model = modules.sd_models.load_model()
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
+    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))

 def webui():
+    initialize()
+
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
         print(f'Interrupted with signal {sig} in {frame}')
@@ -98,7 +102,7 @@ def webui():
     demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)

-    app,local_url,share_url = demo.launch(
+    app, local_url, share_url = demo.launch(
         share=cmd_opts.share,
         server_name="0.0.0.0" if cmd_opts.listen else None,
         server_port=cmd_opts.port,
@@ -129,6 +133,5 @@ def webui():
     print('Restarting Gradio')

 if __name__ == "__main__":
     webui()
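Note: the webui.py change gathers the module-level setup calls into initialize(), so importing the module no longer loads models as a side effect; the work runs only when webui() is called. A minimal sketch of that startup shape, including the immediate-exit ctrl+c handler (the function bodies are placeholders, not the real setup):

    import os
    import signal

    def initialize():
        # placeholder for the model/face-restorer/upscaler setup moved out of module scope
        print("loading models...")

    def webui():
        initialize()

        # exit at ctrl+c without waiting for background threads
        def sigint_handler(sig, frame):
            print(f'Interrupted with signal {sig} in {frame}')
            os._exit(0)

        signal.signal(signal.SIGINT, sigint_handler)
        print("launching UI...")

    if __name__ == "__main__":
        webui()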