Merge branch 'master' into patch-1
commit 6165f07e74

javascript/generationParams.js | 33 lines (new file)
@@ -0,0 +1,33 @@
+// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
+
+let txt2img_gallery, img2img_gallery, modal = undefined;
+onUiUpdate(function(){
+    if (!txt2img_gallery) {
+        txt2img_gallery = attachGalleryListeners("txt2img")
+    }
+    if (!img2img_gallery) {
+        img2img_gallery = attachGalleryListeners("img2img")
+    }
+    if (!modal) {
+        modal = gradioApp().getElementById('lightboxModal')
+        modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
+    }
+});
+
+let modalObserver = new MutationObserver(function(mutations) {
+    mutations.forEach(function(mutationRecord) {
+        let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
+        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
+            gradioApp().getElementById(selectedTab+"_generation_info_button").click()
+    });
+});
+
+function attachGalleryListeners(tab_name) {
+    let gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
+    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
+    gallery?.addEventListener('keydown', (e) => {
+        if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
+            gradioApp().getElementById(tab_name+"_generation_info_button").click()
+    });
+    return gallery;
+}
modules/api/api.py
@@ -15,6 +15,9 @@ from modules.sd_models import checkpoints_list
 from modules.realesrgan_model import get_realesrgan_models
 from typing import List
 
+if shared.cmd_opts.deepdanbooru:
+    from modules.deepbooru import get_deepbooru_tags
+
 def upscaler_to_index(name: str):
     try:
         return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
@@ -220,11 +223,20 @@ class Api:
         if image_b64 is None:
             raise HTTPException(status_code=404, detail="Image not found")
 
-        img = self.__base64_to_image(image_b64)
+        img = decode_base64_to_image(image_b64)
+        img = img.convert('RGB')
 
         # Override object param
         with self.queue_lock:
-            processed = shared.interrogator.interrogate(img)
+            if interrogatereq.model == "clip":
+                processed = shared.interrogator.interrogate(img)
+            elif interrogatereq.model == "deepdanbooru":
+                if shared.cmd_opts.deepdanbooru:
+                    processed = get_deepbooru_tags(img)
+                else:
+                    raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching to use this model.")
+            else:
+                raise HTTPException(status_code=404, detail="Model not found")
 
         return InterrogateResponse(caption=processed)
 
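Note: for reference, a minimal client-side sketch of the new interrogate model parameter. This is a sketch, assuming a local instance launched with --api on the default port; sample.png is a placeholder filename.

import base64
import requests

# Encode an image as the Base64 string InterrogateRequest expects.
with open("sample.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# model may be "clip" (the default) or "deepdanbooru"; the latter returns
# a 404 unless the server was launched with --deepdanbooru.
payload = {"image": image_b64, "model": "deepdanbooru"}
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/interrogate", json=payload)
print(resp.json()["caption"])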
modules/api/models.py
@@ -170,6 +170,7 @@ class ProgressResponse(BaseModel):
 
 class InterrogateRequest(BaseModel):
     image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+    model: str = Field(default="clip", title="Model", description="The interrogate model used.")
 
 class InterrogateResponse(BaseModel):
     caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
modules/ngrok.py
@@ -1,14 +1,23 @@
 from pyngrok import ngrok, conf, exception
 
 
 def connect(token, port, region):
+    account = None
     if token == None:
         token = 'None'
+    else:
+        if ':' in token:
+            # token = authtoken:username:password
+            account = token.split(':')[1] + ':' + token.split(':')[-1]
+            token = token.split(':')[0]
     config = conf.PyngrokConfig(
         auth_token=token, region=region
         )
     try:
-        public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
+        if account == None:
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
+        else:
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
     except exception.PyngrokNgrokError:
         print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
               f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
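Note: a quick sketch of how the extended token format is split. The authtoken:username:password convention comes from the comment in the diff; the values below are placeholders.

# Token may be a bare authtoken or "authtoken:username:password".
token = "2abcDEF:alice:s3cret"  # placeholder values

account = None
if ':' in token:
    parts = token.split(':')
    account = parts[1] + ':' + parts[-1]  # "username:password" for HTTP basic auth
    token = parts[0]                      # bare authtoken for PyngrokConfig

print(token)    # 2abcDEF
print(account)  # alice:s3cret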
modules/sd_models.py
@@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+
+    if cache_enabled:
         sd_vae.restore_base_vae(model)
-        checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
 
     vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
 
-    if checkpoint_info not in checkpoints_loaded:
+    if cache_enabled and checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
+        vae_message = f" with {vae_name} VAE" if vae_name else ""
+        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+    else:
+        # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -181,6 +189,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         model.load_state_dict(sd, strict=False)
         del sd
 
+        if cache_enabled:
+            # cache newly loaded model
+            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+
         if shared.cmd_opts.opt_channelslast:
             model.to(memory_format=torch.channels_last)
 
@@ -199,14 +211,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
 
             model.first_stage_model.to(devices.dtype_vae)
 
-    else:
-        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
-        vae_message = f" with {vae_name} VAE" if vae_name else ""
-        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-
-    if shared.opts.sd_checkpoint_cache > 0:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+    # clean up cache if limit is reached
+    if cache_enabled:
+        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
             checkpoints_loaded.popitem(last=False)  # LRU
 
     model.sd_model_hash = sd_model_hash
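Note: checkpoints_loaded is an OrderedDict used as an LRU cache; popitem(last=False) evicts the oldest entry, and the limit is checked against sd_checkpoint_cache + 1 because the currently loaded model also lives in the dict. A standalone sketch of that eviction policy (names and payloads are illustrative, not the module's own):

from collections import OrderedDict

checkpoints_loaded = OrderedDict()
cache_limit = 2  # stands in for shared.opts.sd_checkpoint_cache

for name in ["model-a", "model-b", "model-c", "model-d"]:
    checkpoints_loaded[name] = f"state_dict of {name}"  # placeholder payload
    while len(checkpoints_loaded) > cache_limit + 1:  # count the current model too
        checkpoints_loaded.popitem(last=False)  # evict oldest (LRU)

print(list(checkpoints_loaded))  # ['model-b', 'model-c', 'model-d']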
modules/shared.py
@@ -319,6 +319,8 @@ options_templates.update(options_section(('system', "System"), {
 
 options_templates.update(options_section(('training', "Training"), {
     "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+    "shuffle_tags": OptionInfo(False, "Shuffle tags by ',' when creating texts."),
+    "tag_drop_out": OptionInfo(0, "Drop out tags when creating texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
     "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
modules/textual_inversion/dataset.py
@@ -98,7 +98,12 @@ class PersonalizedBase(Dataset):
     def create_text(self, filename_text):
         text = random.choice(self.lines)
         text = text.replace("[name]", self.placeholder_token)
-        text = text.replace("[filewords]", filename_text)
+        tags = filename_text.split(',')
+        if shared.opts.tag_drop_out != 0:
+            tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
+        if shared.opts.shuffle_tags:
+            random.shuffle(tags)
+        text = text.replace("[filewords]", ','.join(tags))
         return text
 
     def __len__(self):
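Note: a small sketch of what the new create_text logic does to a [filewords] caption. The option values below are illustrative stand-ins for shared.opts:

import random

filename_text = "1girl, long hair, smile, outdoors"
tag_drop_out = 0.5   # stands in for shared.opts.tag_drop_out
shuffle_tags = True  # stands in for shared.opts.shuffle_tags

tags = filename_text.split(',')
if tag_drop_out != 0:
    # each tag survives with probability 1 - tag_drop_out
    tags = [t for t in tags if random.random() > tag_drop_out]
if shuffle_tags:
    random.shuffle(tags)

print("a photo of [filewords]".replace("[filewords]", ','.join(tags)))
# e.g. "a photo of  smile,1girl" - a random subset in random order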
modules/ui.py
@@ -566,6 +566,19 @@ def apply_setting(key, value):
     return value
 
 
+def update_generation_info(args):
+    generation_info, html_info, img_index = args
+    try:
+        generation_info = json.loads(generation_info)
+        if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+            return html_info
+        return plaintext_to_html(generation_info["infotexts"][img_index])
+    except Exception:
+        pass
+    # if the json parse or anything else fails, just return the old html_info
+    return html_info
+
+
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
     def refresh():
         refresh_method()
@@ -638,6 +651,15 @@ Requested path was: {f}
                 with gr.Group():
                     html_info = gr.HTML()
                     generation_info = gr.Textbox(visible=False)
+                    if tabname == 'txt2img' or tabname == 'img2img':
+                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+                        generation_info_button.click(
+                            fn=update_generation_info,
+                            _js="(x, y) => [x, y, selected_gallery_index()]",
+                            inputs=[generation_info, html_info],
+                            outputs=[html_info],
+                            preprocess=False
+                        )
 
                 save.click(
                     fn=wrap_gradio_call(save_files),
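Note: update_generation_info expects the hidden generation_info textbox to hold JSON with an infotexts list, one entry per gallery image; the _js hook appends selected_gallery_index() as the third argument. A sketch of the assumed payload shape (field values are illustrative):

import json

generation_info = json.dumps({
    "infotexts": [
        "a castle, Steps: 20, Sampler: Euler a, Seed: 1",
        "a castle, Steps: 20, Sampler: Euler a, Seed: 2",
    ]
})

img_index = 1  # what selected_gallery_index() returns on the JS side
infotexts = json.loads(generation_info)["infotexts"]
if 0 <= img_index < len(infotexts):
    print(infotexts[img_index])  # shown in place of the old html_info text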
scripts/prompt_matrix.py
@@ -80,6 +80,8 @@ class Script(scripts.Script):
         grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
         grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
         processed.images.insert(0, grid)
+        processed.index_of_first_image = 1
+        processed.infotexts.insert(0, processed.infotexts[0])
 
         if opts.grid_save:
             images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p)
scripts/prompts_from_file.py
@@ -145,6 +145,8 @@ class Script(scripts.Script):
         state.job_count = job_count
 
         images = []
+        all_prompts = []
+        infotexts = []
         for n, args in enumerate(jobs):
             state.job = f"{state.job_no + 1} out of {state.job_count}"
 
@@ -157,5 +159,7 @@ class Script(scripts.Script):
 
             if checkbox_iterate:
                 p.seed = p.seed + (p.batch_size * p.n_iter)
+            all_prompts += proc.all_prompts
+            infotexts += proc.infotexts
 
-        return Processed(p, images, p.seed, "")
+        return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)