Merge remote-tracking branch 'upstream/master' into PowerShell
Commit fbdc89c9b1
Changed paths:
  configs/
  extensions-builtin/
  javascript/
  models/VAE-approx/
  modules/ (api, codeformer, generation_parameters_copypaste.py, hypernetworks, images.py, img2img.py, interrogate.py, memmon.py, processing.py, safe.py, script_callbacks.py, sd_hijack.py, sd_hijack_clip.py, sd_hijack_inpainting.py, sd_hijack_xlmr.py, sd_models.py, sd_samplers.py, sd_vae.py, sd_vae_approx.py, shared.py, textual_inversion, txt2img.py, ui.py, ui_components.py, ui_tempdir.py, xlmr.py)
  scripts/
  style.css
  v2-inference-v.yaml
configs/alt-diffusion-inference.yaml (new file, 72 lines)
@@ -0,0 +1,72 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: modules.xlmr.BertSeriesModelWithTransformation
+      params:
+        name: "XLMR-Large"
@@ -26,7 +26,7 @@ class LDSR:
         global cached_ldsr_model

         if shared.opts.ldsr_cached and cached_ldsr_model is not None:
-            print(f"Loading model from cache")
+            print("Loading model from cache")
             model: torch.nn.Module = cached_ldsr_model
         else:
             print(f"Loading model from {self.modelPath}")
extensions-builtin/roll-artist/scripts/roll-artist.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+import random
+
+from modules import script_callbacks, shared
+import gradio as gr
+
+art_symbol = '\U0001f3a8'  # 🎨
+global_prompt = None
+related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
+
+
+def roll_artist(prompt):
+    allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
+    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
+
+    return prompt + ", " + artist.name if prompt != '' else artist.name
+
+
+def add_roll_button(prompt):
+    roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+
+    roll.click(
+        fn=roll_artist,
+        _js="update_txt2img_tokens",
+        inputs=[
+            prompt,
+        ],
+        outputs=[
+            prompt,
+        ]
+    )
+
+
+def after_component(component, **kwargs):
+    global global_prompt
+
+    elem_id = kwargs.get('elem_id', None)
+    if elem_id not in related_ids:
+        return
+
+    if elem_id == "txt2img_prompt":
+        global_prompt = component
+    elif elem_id == "txt2img_clear_prompt":
+        add_roll_button(global_prompt)
+    elif elem_id == "img2img_prompt":
+        global_prompt = component
+    elif elem_id == "img2img_clear_prompt":
+        add_roll_button(global_prompt)
+
+
+script_callbacks.on_after_component(after_component)
@@ -97,7 +97,10 @@ titles = {

    "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n  rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",

-    "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc."
+    "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
+
+    "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
+    "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
 }
models/VAE-approx/model.pt (new binary file, not shown)
@@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from secrets import compare_digest

 import modules.shared as shared
-from modules import sd_samplers, deepbooru
+from modules import sd_samplers, deepbooru, sd_hijack
 from modules.api.models import *
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.extras import run_extras, run_pnginfo
+from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
+from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
 from modules.sd_models import checkpoints_list
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List

 def upscaler_to_index(name: str):

@@ -97,6 +101,11 @@ class Api:
         self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
         self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
+        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
+        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
+        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)

     def add_api_route(self, path: str, endpoint, **kwargs):
         if shared.cmd_opts.api_auth:

@@ -112,7 +121,6 @@ class Api:

     def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
         populate = txt2imgreq.copy(update={  # Override __init__ params
-            "sd_model": shared.sd_model,
             "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
             "do_not_save_samples": True,
             "do_not_save_grid": True

@@ -120,15 +128,14 @@ class Api:
         )
         if populate.sampler_name:
             populate.sampler_index = None  # prevent a warning later on
-        p = StableDiffusionProcessingTxt2Img(**vars(populate))
-        # Override object param
-
-        shared.state.begin()

         with self.queue_lock:
-            processed = process_images(p)
+            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+
+            shared.state.begin()
+            processed = process_images(p)
+            shared.state.end()

-        shared.state.end()
-
         b64images = list(map(encode_pil_to_base64, processed.images))

@@ -144,7 +151,6 @@ class Api:
             mask = decode_base64_to_image(mask)

         populate = img2imgreq.copy(update={  # Override __init__ params
-            "sd_model": shared.sd_model,
             "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
             "do_not_save_samples": True,
             "do_not_save_grid": True,

@@ -156,16 +162,14 @@ class Api:

         args = vars(populate)
         args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
-        p = StableDiffusionProcessingImg2Img(**args)
-
-        p.init_images = [decode_base64_to_image(x) for x in init_images]
-
-        shared.state.begin()

         with self.queue_lock:
-            processed = process_images(p)
+            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+            p.init_images = [decode_base64_to_image(x) for x in init_images]

-        shared.state.end()
+            shared.state.begin()
+            processed = process_images(p)
+            shared.state.end()

         b64images = list(map(encode_pil_to_base64, processed.images))

@@ -326,6 +330,89 @@ class Api:
     def refresh_checkpoints(self):
         shared.refresh_checkpoints()

+    def create_embedding(self, args: dict):
+        try:
+            shared.state.begin()
+            filename = create_embedding(**args)  # create empty embedding
+            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()  # reload embeddings so new one can be immediately used
+            shared.state.end()
+            return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+        except AssertionError as e:
+            shared.state.end()
+            return TrainResponse(info = "create embedding error: {error}".format(error = e))
+
+    def create_hypernetwork(self, args: dict):
+        try:
+            shared.state.begin()
+            filename = create_hypernetwork(**args)  # create empty embedding
+            shared.state.end()
+            return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
+        except AssertionError as e:
+            shared.state.end()
+            return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
+
+    def preprocess(self, args: dict):
+        try:
+            shared.state.begin()
+            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
+            shared.state.end()
+            return PreprocessResponse(info = 'preprocess complete')
+        except KeyError as e:
+            shared.state.end()
+            return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
+        except AssertionError as e:
+            shared.state.end()
+            return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
+        except FileNotFoundError as e:
+            shared.state.end()
+            return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
+
+    def train_embedding(self, args: dict):
+        try:
+            shared.state.begin()
+            apply_optimizations = shared.opts.training_xattention_optimizations
+            error = None
+            filename = ''
+            if not apply_optimizations:
+                sd_hijack.undo_optimizations()
+            try:
+                embedding, filename = train_embedding(**args)  # can take a long time to complete
+            except Exception as e:
+                error = e
+            finally:
+                if not apply_optimizations:
+                    sd_hijack.apply_optimizations()
+                shared.state.end()
+            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+        except AssertionError as msg:
+            shared.state.end()
+            return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
+
+    def train_hypernetwork(self, args: dict):
+        try:
+            shared.state.begin()
+            initial_hypernetwork = shared.loaded_hypernetwork
+            apply_optimizations = shared.opts.training_xattention_optimizations
+            error = None
+            filename = ''
+            if not apply_optimizations:
+                sd_hijack.undo_optimizations()
+            try:
+                hypernetwork, filename = train_hypernetwork(*args)
+            except Exception as e:
+                error = e
+            finally:
+                shared.loaded_hypernetwork = initial_hypernetwork
+                shared.sd_model.cond_stage_model.to(devices.device)
+                shared.sd_model.first_stage_model.to(devices.device)
+                if not apply_optimizations:
+                    sd_hijack.apply_optimizations()
+                shared.state.end()
+            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+        except AssertionError as msg:
+            shared.state.end()
+            return TrainResponse(info = "train embedding error: {error}".format(error = error))
+
     def launch(self, server_name, port):
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port)
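For reference, a minimal client-side sketch of the new create/train endpoints (illustrative only, not part of this commit). It assumes a local webui started with --api on the default port 7860, and the JSON keys simply mirror the keyword arguments of create_embedding(); the parameter names used below are assumptions.

import requests

# the posted dict is passed straight through as **kwargs to create_embedding()
payload = {
    "name": "my-style",             # assumed parameter name
    "num_vectors_per_token": 4,     # assumed parameter name
    "overwrite_old": False,         # assumed parameter name
}
response = requests.post("http://127.0.0.1:7860/sdapi/v1/create/embedding", json=payload)
print(response.json())              # e.g. {"info": "create embedding filename: ..."}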
@@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel):
 class InterrogateResponse(BaseModel):
     caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")

+class TrainResponse(BaseModel):
+    info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+
+class CreateResponse(BaseModel):
+    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+
+class PreprocessResponse(BaseModel):
+    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
+
 fields = {}
 for key, metadata in opts.data_labels.items():
     value = opts.data.get(key)
@@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
                 self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
                 logger.info(f'vqgan is loaded from: {model_path} [params]')
             else:
-                raise ValueError(f'Wrong params!')
+                raise ValueError('Wrong params!')


     def forward(self, x):

@@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
             elif 'params' in chkpt:
                 self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
             else:
-                raise ValueError(f'Wrong params!')
+                raise ValueError('Wrong params!')

     def forward(self, x):
         return self.main(x)
@@ -38,7 +38,7 @@ def quote(text):
 def image_from_url_text(filedata):
     if type(filedata) == dict and filedata["is_file"]:
         filename = filedata["name"]
-        is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
+        is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets])
         assert is_in_right_dir, 'trying to open image file outside of allowed directories'

         return Image.open(filename)
@@ -277,7 +277,7 @@ def load_hypernetwork(filename):
             print(traceback.format_exc(), file=sys.stderr)
     else:
         if shared.loaded_hypernetwork is not None:
-            print(f"Unloading hypernetwork")
+            print("Unloading hypernetwork")

         shared.loaded_hypernetwork = None

@@ -378,6 +378,32 @@ def report_statistics(loss_info:dict):
         print(e)


+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    # Remove illegal characters from name.
+    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
+    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"
+
+    if type(layer_structure) == str:
+        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+        name=name,
+        enable_sizes=[int(x) for x in enable_sizes],
+        layer_structure=layer_structure,
+        activation_func=activation_func,
+        weight_init=weight_init,
+        add_layer_norm=add_layer_norm,
+        use_dropout=use_dropout,
+    )
+    hypernet.save(fn)
+
+    shared.reload_hypernetworks()
+
+    return fn
+
+
 def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.

@@ -417,7 +443,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,

     initial_step = hypernetwork.step or 0
     if initial_step >= steps:
-        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        shared.state.textinfo = "Model has already been trained beyond specified max steps"
         return hypernetwork, filename

     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
@@ -3,39 +3,16 @@ import os
 import re

 import gradio as gr
-import modules.textual_inversion.preprocess
-import modules.textual_inversion.textual_inversion
 import modules.hypernetworks.hypernetwork
 from modules import devices, sd_hijack, shared
-from modules.hypernetworks import hypernetwork

 not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)

 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
-    # Remove illegal characters from name.
-    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
-
-    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
-    if not overwrite_old:
-        assert not os.path.exists(fn), f"file {fn} already exists"
-
-    if type(layer_structure) == str:
-        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
-
-    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
-        name=name,
-        enable_sizes=[int(x) for x in enable_sizes],
-        layer_structure=layer_structure,
-        activation_func=activation_func,
-        weight_init=weight_init,
-        add_layer_norm=add_layer_norm,
-        use_dropout=use_dropout,
-    )
-    hypernet.save(fn)
-
-    shared.reload_hypernetworks()
-
-    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
+
+    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""


 def train_hypernetwork(*args):
@@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):

     cols = math.ceil(len(imgs) / rows)

-    w, h = imgs[0].size
-    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+    script_callbacks.image_grid_callback(params)

-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
+    w, h = imgs[0].size
+    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
+
+    for i, img in enumerate(params.imgs):
+        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))

     return grid

@@ -525,6 +528,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
             image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

         elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+            if image_to_save.mode == 'RGBA':
+                image_to_save = image_to_save.convert("RGB")
+
             image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

             if opts.enable_pnginfo and info is not None:

@@ -599,7 +605,7 @@ def read_info_from_image(image):
 Negative prompt: {json_info["uc"]}
 Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
     except Exception:
-        print(f"Error parsing NovelAI image generation parameters:", file=sys.stderr)
+        print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)

     return geninfo, items
|
|||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
|
|
@@ -135,7 +135,7 @@ class InterrogateModels:
         return caption[0]

     def interrogate(self, pil_image):
-        res = None
+        res = ""

         try:

@@ -172,7 +172,7 @@ class InterrogateModels:
                     res += ", " + match

         except Exception:
-            print(f"Error interrogating", file=sys.stderr)
+            print("Error interrogating", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
             res += "<error>"
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
     def read(self):
         if not self.disabled:
             free, total = torch.cuda.mem_get_info()
+            self.data["free"] = free
             self.data["total"] = total

             torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active"] = torch_stats["active.all.current"]
             self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
             self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
             self.data["system_peak"] = total - self.data["min_free"]
@@ -239,7 +239,7 @@ class StableDiffusionProcessing():


 class Processed:
-    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
         self.images = images_list
         self.prompt = p.prompt
         self.negative_prompt = p.negative_prompt

@@ -247,6 +247,7 @@ class Processed:
         self.subseed = subseed
         self.subseed_strength = p.subseed_strength
         self.info = info
+        self.comments = comments
         self.width = p.width
         self.height = p.height
         self.sampler_name = p.sampler_name

@@ -338,13 +339,14 @@ def slerp(val, low, high):


 def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
     xs = []

     # if we have multiple seeds, this means we are working with batch size>1; this then
     # enables the generation of additional tensors with noise that the sampler will use during its processing.
     # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
     # produce the same images as with two batches [100], [101].
-    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
         sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
     else:
         sampler_noises = None

@@ -384,8 +386,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
         if sampler_noises is not None:
             cnt = p.sampler.number_of_needed_noises(p)

-            if opts.eta_noise_seed_delta > 0:
-                torch.manual_seed(seed + opts.eta_noise_seed_delta)
+            if eta_noise_seed_delta > 0:
+                torch.manual_seed(seed + eta_noise_seed_delta)

             for j in range(cnt):
                 sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

@@ -645,7 +647,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

     devices.torch_gc()

-    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+    res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)

     if p.scripts is not None:
         p.scripts.postprocess(p, res)
@@ -103,7 +103,7 @@ def check_pt(filename, extra_handler):


 def load(filename, *args, **kwargs):
-    return load_with_extra(filename, *args, **kwargs)
+    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)


 def load_with_extra(filename, extra_handler=None, *args, **kwargs):

@@ -137,19 +137,56 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
     except pickle.UnpicklingError:
         print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
-        print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
-        print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+        print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+        print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
         return None

     except Exception:
         print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
-        print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
-        print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+        print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+        print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
         return None

     return unsafe_torch_load(filename, *args, **kwargs)


+class Extra:
+    """
+    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+    (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+    if module == 'torch' and name in ['float64', 'float16']:
+        return getattr(torch, name)
+
+    return None
+
+with safe.Extra(handler):
+    x = torch.load('model.pt')
+```
+    """
+
+    def __init__(self, handler):
+        self.handler = handler
+
+    def __enter__(self):
+        global global_extra_handler
+
+        assert global_extra_handler is None, 'already inside an Extra() block'
+        global_extra_handler = self.handler
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global global_extra_handler
+
+        global_extra_handler = None
+
+
 unsafe_torch_load = torch.load
 torch.load = load
+global_extra_handler = None
@@ -51,6 +51,13 @@ class UiTrainTabParams:
         self.txt2img_preview_params = txt2img_preview_params


+class ImageGridLoopParams:
+    def __init__(self, imgs, cols, rows):
+        self.imgs = imgs
+        self.cols = cols
+        self.rows = rows
+
+
 ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
 callback_map = dict(
     callbacks_app_started=[],

@@ -63,6 +70,7 @@ callback_map = dict(
     callbacks_cfg_denoiser=[],
     callbacks_before_component=[],
     callbacks_after_component=[],
+    callbacks_image_grid=[],
 )

@@ -155,6 +163,14 @@ def after_component_callback(component, **kwargs):
         report_exception(c, 'after_component_callback')


+def image_grid_callback(params: ImageGridLoopParams):
+    for c in callback_map['callbacks_image_grid']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'image_grid')
+
+
 def add_callback(callbacks, fun):
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if len(stack) > 0 else 'unknown file'

@@ -255,3 +271,11 @@ def on_before_component(callback):
 def on_after_component(callback):
     """register a function to be called after a component is created. See on_before_component for more."""
     add_callback(callback_map['callbacks_after_component'], callback)
+
+
+def on_image_grid(callback):
+    """register a function to be called before making an image grid.
+    The callback is called with one argument:
+       - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
+    """
+    add_callback(callback_map['callbacks_image_grid'], callback)
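For reference, a minimal sketch of how an extension could use the new image-grid hook (illustrative only, not part of this commit; it assumes the snippet lives in a script file loaded by the webui):

from modules import script_callbacks


def single_row_grids(params):
    # params is an ImageGridLoopParams(imgs, cols, rows); image_grid() reads it back
    # after the callback runs, so it can be modified in place.
    params.cols = len(params.imgs)
    params.rows = 1


script_callbacks.on_image_grid(single_row_grids)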
@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
 from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

 from modules.sd_hijack_optimizations import invokeAI_mps_available

@@ -68,6 +68,7 @@ def fix_checkpoint():
     ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
     ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward

+
 class StableDiffusionModelHijack:
     fixes = None
     comments = []

@@ -78,17 +79,25 @@ class StableDiffusionModelHijack:
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

     def hijack(self, m):
-        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+            model_embeddings = m.cond_stage_model.roberta.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
+
+        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
             m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
         elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
             m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
             m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

-        self.clip = m.cond_stage_model
-
         apply_optimizations()

+        self.clip = m.cond_stage_model
+
         fix_checkpoint()

@@ -101,7 +110,11 @@ class StableDiffusionModelHijack:
         self.layers = flatten(m)

     def undo_hijack(self, m):
-        if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
+
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+            m.cond_stage_model = m.cond_stage_model.wrapped
+
+        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped

             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings

@@ -129,8 +142,8 @@ class StableDiffusionModelHijack:

     def tokenize(self, text):
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)

+        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)


 class EmbeddingsWithFixes(torch.nn.Module):
@@ -5,7 +5,6 @@ import torch
 from modules import prompt_parser, devices
 from modules.shared import opts


 def get_target_prompt_token_count(token_count):
     return math.ceil(max(token_count, 1) / 75) * 75

@@ -254,10 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
         self.tokenizer = wrapped.tokenizer
-        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
+        vocab = self.tokenizer.get_vocab()
+
+        self.comma_token = vocab.get(',</w>', None)

         self.token_mults = {}
-        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
         for text, ident in tokens_with_parens:
             mult = 1.0
             for c in text:

@@ -296,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
     def encode_embedding_init_text(self, init_text, nvpt):
         embedding_layer = self.wrapped.transformer.text_model.embeddings
         ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
-        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

         return embedded
@@ -178,7 +178,7 @@ def sample_plms(self,
     # sampling
     C, H, W = shape
     size = (batch_size, C, H, W)
-    print(f'Data shape for PLMS sampling is {size}')
+    # print(f'Data shape for PLMS sampling is {size}') # remove unnecessary message

     samples, intermediates = self.plms_sampling(conditioning, size,
                                                 callback=callback,
modules/sd_hijack_xlmr.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.id_start = wrapped.config.bos_token_id
+        self.id_end = wrapped.config.eos_token_id
+        self.id_pad = wrapped.config.pad_token_id
+
+        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma
+
+    def encode_with_transformers(self, tokens):
+        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+        # layer to work with - you have to use the last
+
+        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+        z = features['projection_state']
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        embedding_layer = self.wrapped.roberta.embeddings
+        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+        return embedded
@@ -117,13 +117,13 @@ def select_checkpoint():
         return checkpoint_info

     if len(checkpoints_list) == 0:
-        print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+        print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
         if shared.cmd_opts.ckpt is not None:
             print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
         print(f" - directory {model_path}", file=sys.stderr)
         if shared.cmd_opts.ckpt_dir is not None:
             print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
-        print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+        print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
         exit(1)

     checkpoint_info = next(iter(checkpoints_list.values()))

@@ -324,7 +324,10 @@ def load_model(checkpoint_info=None):

     script_callbacks.model_loaded_callback(sd_model)

-    print(f"Model loaded.")
+    print("Model loaded.")
+
+    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True)  # Reload embeddings after model load as they may or may not fit the model
+
     return sd_model

@@ -359,5 +362,5 @@ def reload_model_weights(sd_model=None, info=None):
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)

-    print(f"Weights loaded.")
+    print("Weights loaded.")
     return sd_model
@@ -9,7 +9,7 @@ import k_diffusion.sampling
 import torchsde._brownian.brownian_interval
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing, images
+from modules import prompt_parser, devices, processing, images, sd_vae_approx

 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared

@@ -106,28 +106,31 @@ def setup_img2img_steps(p, steps=None):
     return steps, t_enc


-def single_sample_to_image(sample, approximation=False):
-    if approximation:
-        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
-        coefs = torch.tensor(
-            [[ 0.298, 0.207, 0.208],
-            [ 0.187, 0.286, 0.173],
-            [-0.158, 0.189, 0.264],
-            [-0.184, -0.271, -0.473]]).to(sample.device)
-        x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+    if approximation is None:
+        approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+    if approximation == 2:
+        x_sample = sd_vae_approx.cheap_approximation(sample)
+    elif approximation == 1:
+        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
     else:
         x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]

     x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = x_sample.astype(np.uint8)
     return Image.fromarray(x_sample)


-def sample_to_image(samples, index=0, approximation=False):
+def sample_to_image(samples, index=0, approximation=None):
     return single_sample_to_image(samples[index], approximation)


-def samples_to_image_grid(samples, approximation=False):
+def samples_to_image_grid(samples, approximation=None):
     return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])

@@ -136,7 +139,7 @@ def store_latent(decoded):

     if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
         if not shared.parallel_processing_allowed:
-            shared.state.current_image = sample_to_image(decoded, approximation=opts.show_progress_approximate)
+            shared.state.current_image = sample_to_image(decoded)


 class InterruptedException(BaseException):

@@ -462,7 +465,9 @@ class KDiffusionSampler:
         if p.sampler_noise_scheduler_override:
             sigmas = p.sampler_noise_scheduler_override(steps)
         elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
-            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
+            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
         else:
             sigmas = self.model_wrap.get_sigmas(steps)
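For reference (illustrative only, not part of this commit): when no explicit approximation index is passed, the preview decode path is now selected from the new show_progress_type option through approximation_indexes, roughly like this inside the webui environment:

from modules import sd_samplers
from modules.shared import opts

mode = sd_samplers.approximation_indexes.get(opts.show_progress_type, 0)
# 0 -> full VAE decode, 1 -> sd_vae_approx.model() ("Approx NN"), 2 -> cheap_approximation() ("Approx cheap")
print(opts.show_progress_type, "->", mode)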
@@ -1,5 +1,6 @@
 import torch
 import os
+import collections
 from collections import namedtuple
 from modules import shared, devices, script_callbacks
 from modules.paths import models_path

@@ -30,6 +31,7 @@ base_vae = None
 loaded_vae_file = None
 checkpoint_info = None

+checkpoints_loaded = collections.OrderedDict()

 def get_base_vae(model):
     if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:

@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
     global first_load, vae_dict, vae_list, loaded_vae_file
     # save_settings = False

+    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
+
     if vae_file:
-        assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
-        print(f"Loading VAE weights from: {vae_file}")
-        store_base_vae(model)
-        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
-        _load_vae_dict(model, vae_dict_1)
+        if cache_enabled and vae_file in checkpoints_loaded:
+            # use vae checkpoint cache
+            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+            store_base_vae(model)
+            _load_vae_dict(model, checkpoints_loaded[vae_file])
+        else:
+            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
+            print(f"Loading VAE weights from: {vae_file}")
+            store_base_vae(model)
+            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+            vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+            _load_vae_dict(model, vae_dict_1)
+
+            if cache_enabled:
+                # cache newly loaded vae
+                checkpoints_loaded[vae_file] = vae_dict_1.copy()
+
+        # clean up cache if limit is reached
+        if cache_enabled:
+            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
+                checkpoints_loaded.popitem(last=False)  # LRU

         # If vae used is not in dict, update it
         # It will be removed on refresh though

@@ -208,5 +227,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)

-    print(f"VAE Weights loaded.")
+    print("VAE Weights loaded.")
     return sd_model
modules/sd_vae_approx.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import os
+
+import torch
+from torch import nn
+from modules import devices, paths
+
+sd_vae_approx_model = None
+
+
+class VAEApprox(nn.Module):
+    def __init__(self):
+        super(VAEApprox, self).__init__()
+        self.conv1 = nn.Conv2d(4, 8, (7, 7))
+        self.conv2 = nn.Conv2d(8, 16, (5, 5))
+        self.conv3 = nn.Conv2d(16, 32, (3, 3))
+        self.conv4 = nn.Conv2d(32, 64, (3, 3))
+        self.conv5 = nn.Conv2d(64, 32, (3, 3))
+        self.conv6 = nn.Conv2d(32, 16, (3, 3))
+        self.conv7 = nn.Conv2d(16, 8, (3, 3))
+        self.conv8 = nn.Conv2d(8, 3, (3, 3))
+
+    def forward(self, x):
+        extra = 11
+        x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        x = nn.functional.pad(x, (extra, extra, extra, extra))
+
+        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
+            x = layer(x)
+            x = nn.functional.leaky_relu(x, 0.1)
+
+        return x
+
+
+def model():
+    global sd_vae_approx_model
+
+    if sd_vae_approx_model is None:
+        sd_vae_approx_model = VAEApprox()
+        sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+        sd_vae_approx_model.eval()
+        sd_vae_approx_model.to(devices.device, devices.dtype)
+
+    return sd_vae_approx_model
+
+
+def cheap_approximation(sample):
+    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
+
+    coefs = torch.tensor([
+        [0.298, 0.207, 0.208],
+        [0.187, 0.286, 0.173],
+        [-0.158, 0.189, 0.264],
+        [-0.184, -0.271, -0.473],
+    ]).to(sample.device)
+
+    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+
+    return x_sample
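For reference, a small sketch of the cheap approximation on its own (illustrative only, not part of this commit; it assumes the webui environment so that modules.sd_vae_approx is importable). The fixed linear projection maps the 4 latent channels to 3 RGB channels without changing spatial size:

import torch

from modules import sd_vae_approx

latent = torch.randn(4, 64, 64)               # stand-in latent for a 512x512 image
rgb = sd_vae_approx.cheap_approximation(latent)
print(rgb.shape)                              # torch.Size([3, 64, 64])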
@ -23,7 +23,7 @@ demo = None
|
|||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
|
@ -168,7 +168,7 @@ class State:
|
|||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.skipped,
|
||||
"interrupted": self.interrupted,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_no": self.job_no,
|
||||
|
@ -212,9 +212,9 @@ class State:
|
|||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent, approximation=opts.show_progress_approximate)
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
else:
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent, approximation=opts.show_progress_approximate)
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
|
@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), {
|
|||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
|
@ -367,13 +368,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
||||
|
@ -392,7 +397,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_approximate": OptionInfo(False, "Calculate small previews using fast linear approximation instead of VAE"),
|
||||
"show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
|
@ -405,6 +410,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
}))
|
||||
|
|
|
@@ -23,6 +23,8 @@ class Embedding:
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None
        self.vectors = 0
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None

@@ -57,8 +59,10 @@ class EmbeddingDatabase:
    def __init__(self, embeddings_dir):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = []
        self.dir_mtime = None
        self.embeddings_dir = embeddings_dir
        self.expected_shape = -1

    def register_embedding(self, embedding, model):

@@ -75,20 +79,24 @@ class EmbeddingDatabase:
        return embedding

    def load_textual_inversion_embeddings(self):
    def get_expected_shape(self):
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_textual_inversion_embeddings(self, force_reload = False):
        mt = os.path.getmtime(self.embeddings_dir)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
        if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings = []
        self.expected_shape = self.get_expected_shape()

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            data = []

            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                embed_image = Image.open(path)
                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:

@@ -122,7 +130,13 @@ class EmbeddingDatabase:
            embedding.step = data.get('step', None)
            embedding.sd_checkpoint = data.get('sd_checkpoint', None)
            embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            self.register_embedding(embedding, shared.sd_model)
            embedding.vectors = vec.shape[0]
            embedding.shape = vec.shape[-1]

            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                self.register_embedding(embedding, shared.sd_model)
            else:
                self.skipped_embeddings.append(name)

        for fn in os.listdir(self.embeddings_dir):
            try:

@@ -137,8 +151,9 @@ class EmbeddingDatabase:
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
        print("Embeddings:", ', '.join(self.word_embeddings.keys()))
        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")

    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]

@@ -263,7 +278,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
    initial_step = embedding.step or 0
    if initial_step >= steps:
        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename
    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
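
A minimal sketch (not part of the diff) of how the new force_reload flag and the expected-shape check would be exercised by calling code:

# Illustrative only: after new embedding files are written, callers can bypass the
# mtime check and rebuild the database; embeddings whose vector width does not
# match the loaded model's expected shape end up in skipped_embeddings.
from modules import sd_hijack

db = sd_hijack.model_hijack.embedding_db
db.load_textual_inversion_embeddings(force_reload=True)
print(len(db.word_embeddings), "loaded,", len(db.skipped_embeddings), "skipped")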
@@ -59,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
    if opts.do_not_show_images:
        processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info)
    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
140
modules/ui.py

@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call

from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, ui_components
from modules.paths import script_path

from modules.shared import opts, cmd_opts, restricted_opts

@@ -80,7 +80,6 @@ css_hide_progressbar = """
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -159,7 +158,7 @@ def save_files(js_data, images, do_make_zip, index):
            zip_file.writestr(filenames[i], f.read())
        fullfns.insert(0, zip_filepath)

    return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")

@@ -234,13 +233,6 @@ def check_progress_call_initial(id_part):
    return check_progress_call(id_part)


def roll_artist(prompt):
    allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])

    return prompt + ", " + artist.name if prompt != '' else artist.name


def visit(x, func, path=""):
    if hasattr(x, 'children'):
        for c in x.children:
@@ -270,7 +262,7 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):


def interrogate(image):
    prompt = shared.interrogator.interrogate(image)
    prompt = shared.interrogator.interrogate(image.convert("RGB"))

    return gr_show(True) if prompt is None else prompt

@@ -403,7 +395,6 @@ def create_toprow(is_img2img):
                )

            with gr.Column(scale=1, elem_id="roll_col"):
                roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
                paste = gr.Button(value=paste_symbol, elem_id="paste")
                save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
                prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")

@@ -452,7 +443,7 @@ def create_toprow(is_img2img):
                    prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
                    prompt_style2.save_to_config = True

    return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
    return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button


def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -532,7 +523,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
        return gr.update(**(args or {}))

    refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
    refresh_button = ui_components.ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],

@@ -570,13 +561,14 @@ Requested path was: {f}
    generation_info = None
    with gr.Column():
        with gr.Row():
        with gr.Row(elem_id=f"image_buttons_{tabname}"):
            open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')

            if tabname != "extras":
                save = gr.Button('Save', elem_id=f'save_{tabname}')
                save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')

            buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
            button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
            open_folder_button = gr.Button(folder_symbol, elem_id=button_id)

        open_folder_button.click(
            fn=lambda: open_folder(opts.outdir_samples or outdir),
@@ -585,14 +577,13 @@ Requested path was: {f}
        )

        if tabname != "extras":
            with gr.Row():
                do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)

            with gr.Row():
                download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)

            with gr.Group():
                html_info = gr.HTML()
                html_log = gr.HTML()

                generation_info = gr.Textbox(visible=False)
                if tabname == 'txt2img' or tabname == 'img2img':
                    generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
@@ -606,25 +597,54 @@ Requested path was: {f}

            save.click(
                fn=wrap_gradio_call(save_files),
                _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
                _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
                inputs=[
                    generation_info,
                    result_gallery,
                    do_make_zip,
                    html_info,
                    html_info,
                ],
                outputs=[
                    download_files,
                    html_info,
                    html_info,
                    html_info,
                    html_log,
                ]
            )

            save_zip.click(
                fn=wrap_gradio_call(save_files),
                _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
                inputs=[
                    generation_info,
                    result_gallery,
                    html_info,
                    html_info,
                ],
                outputs=[
                    download_files,
                    html_log,
                ]
            )

        else:
            html_info_x = gr.HTML()
            html_info = gr.HTML()
            html_log = gr.HTML()

        parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
        return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
        return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log


def create_sampler_and_steps_selection(choices, tabname):
    if opts.samplers_in_dropdown:
        with gr.Row(elem_id=f"sampler_selection_{tabname}"):
            sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
    else:
        with gr.Group(elem_id=f"sampler_selection_{tabname}"):
            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
            sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")

    return steps, sampler_index


def create_ui():
@@ -639,14 +659,11 @@ def create_ui():
    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)

    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
        txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
        txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)

        dummy_component = gr.Label(visible=False)
        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)

        with gr.Row(elem_id='txt2img_progress_row'):
            with gr.Column(scale=1):
                pass

@@ -657,9 +674,8 @@ def create_ui():
        setup_progressbar(progressbar, txt2img_preview, 'txt2img')

        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
            with gr.Column(variant='panel', elem_id="txt2img_settings"):
                steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")

                with gr.Group():
                    width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
@@ -686,14 +702,14 @@ def create_ui():
                with gr.Group():
                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui()

            txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
            txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
            parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)

            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)

            txt2img_args = dict(
                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
                _js="submit",
                inputs=[
                    txt2img_prompt,

@@ -720,7 +736,8 @@ def create_ui():
                outputs=[
                    txt2img_gallery,
                    generation_info,
                    html_info
                    html_info,
                    html_log,
                ],
                show_progress=False,
            )
@@ -745,16 +762,6 @@ def create_ui():
                outputs=[hr_options],
            )

            roll.click(
                fn=roll_artist,
                _js="update_txt2img_tokens",
                inputs=[
                    txt2img_prompt,
                ],
                outputs=[
                    txt2img_prompt,
                ]
            )

            txt2img_paste_fields = [
                (txt2img_prompt, "Prompt"),
@@ -797,8 +804,7 @@ def create_ui():
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)

    with gr.Blocks(analytics_enabled=False) as img2img_interface:
        img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)

        img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)

        with gr.Row(elem_id='img2img_progress_row'):
            img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)

@@ -812,7 +818,7 @@ def create_ui():
        setup_progressbar(progressbar, img2img_preview, 'img2img')

        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
            with gr.Column(variant='panel', elem_id="img2img_settings"):

                with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
                    with gr.TabItem('img2img', id='img2img'):

@@ -859,8 +865,7 @@ def create_ui():
                with gr.Row():
                    resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")

                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
                steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")

                with gr.Group():
                    width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -883,7 +888,7 @@ def create_ui():
                with gr.Group():
                    custom_inputs = modules.scripts.scripts_img2img.setup_ui()

            img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
            img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
            parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)

            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)

@@ -915,7 +920,7 @@ def create_ui():
            )

            img2img_args = dict(
                fn=wrap_gradio_gpu_call(modules.img2img.img2img),
                fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
                _js="submit_img2img",
                inputs=[
                    dummy_component,

@@ -954,7 +959,8 @@ def create_ui():
                outputs=[
                    img2img_gallery,
                    generation_info,
                    html_info
                    html_info,
                    html_log,
                ],
                show_progress=False,
            )
@@ -974,18 +980,6 @@ def create_ui():
                outputs=[img2img_prompt],
            )


            roll.click(
                fn=roll_artist,
                _js="update_img2img_tokens",
                inputs=[
                    img2img_prompt,
                ],
                outputs=[
                    img2img_prompt,
                ]
            )

            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
            style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
@@ -1078,10 +1072,10 @@ def create_ui():
                with gr.Group():
                    upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)

        result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
        result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)

        submit.click(
            fn=wrap_gradio_gpu_call(modules.extras.run_extras),
            fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
            _js="get_extras_tab_index",
            inputs=[
                dummy_component,
@@ -1142,8 +1136,14 @@ def create_ui():

            with gr.Row():
                primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
                create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")

                secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
                create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")

                tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
                create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")

            custom_name = gr.Textbox(label="Custom Name (Optional)")
            interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
            interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")

@@ -1157,8 +1157,6 @@ def create_ui():
        with gr.Column(variant='panel'):
            submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

    with gr.Blocks(analytics_enabled=False) as train_interface:
        with gr.Row().style(equal_height=False):
            gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
@@ -1447,7 +1445,7 @@ def create_ui():
            res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
            create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
        else:
            with gr.Row(variant="compact"):
            with ui_components.FormRow():
                res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
                create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
    else:
18
modules/ui_components.py
Normal file

@@ -0,0 +1,18 @@
import gradio as gr


class ToolButton(gr.Button, gr.components.FormComponent):
    """Small button with single emoji as text, fits inside gradio forms"""

    def __init__(self, **kwargs):
        super().__init__(variant="tool", **kwargs)

    def get_block_name(self):
        return "button"


class FormRow(gr.Row, gr.components.FormComponent):
    """Same as gr.Row but fits inside gradio forms"""

    def get_block_name(self):
        return "row"
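
A minimal usage sketch (not part of this commit; the dropdown label and choices are placeholders) of the two new components inside a Blocks layout:

# Illustrative only: ToolButton and FormRow drop into a gr.Blocks layout like any
# other gradio component; create_refresh_button in ui.py now builds ToolButtons.
import gradio as gr
from modules.ui_components import ToolButton, FormRow

with gr.Blocks() as demo:
    with FormRow():
        checkpoint = gr.Dropdown(label="Checkpoint", choices=["model-a", "model-b"])
        refresh = ToolButton(value='\U0001f504')  # 🔄, the same symbol ui.py uses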
@@ -15,7 +15,8 @@ Savedfile = namedtuple("Savedfile", ["name"])
def save_pil_to_file(pil_image, dir=None):
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as and os.path.isfile(already_saved_as):
        shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
        shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(already_saved_as)}

        file_obj = Savedfile(already_saved_as)
        return file_obj

@@ -44,7 +45,7 @@ def on_tmpdir_changed():
    os.makedirs(shared.opts.temp_dir, exist_ok=True)

    shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
    shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(shared.opts.temp_dir)}


def cleanup_tmpdr():
137
modules/xlmr.py
Normal file

@@ -0,0 +1,137 @@
from transformers import BertPreTrainedModel,BertModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional

class BertSeriesConfig(BertConfig):
    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):

        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder

class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class BertSeriesModelWithTransformation(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kargs):
        # modify initialization for autoloading
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob= 0.1
            config.bos_token_id=0
            config.eos_token_id=2
            config.hidden_act='gelu'
            config.hidden_dropout_prob=0.1
            config.hidden_size=1024
            config.initializer_range=0.02
            config.intermediate_size=4096
            config.layer_norm_eps=1e-05
            config.max_position_embeddings=514

            config.num_attention_heads=16
            config.num_hidden_layers=24
            config.output_past=True
            config.pad_token_id=1
            config.position_embedding_type= "absolute"

            config.type_vocab_size= 1
            config.use_cache=True
            config.vocab_size= 250002
            config.project_dim = 768
            config.learn_encoder = False
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size,config.project_dim)
        self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        self.pooler = lambda x: x[:,0]
        self.post_init()
    def encode(self,c):
        device = next(self.parameters()).device
        text = self.tokenizer(c,
            truncation=True,
            max_length=77,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt")
        text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
        text["attention_mask"] = torch.tensor(
            text['attention_mask']).to(device)
        features = self(**text)
        return features['projection_state']

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) :
        r"""
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # last module outputs
        sequence_output = outputs[0]

        # project every module
        sequence_output_ln = self.pre_LN(sequence_output)

        # pooler
        pooler_output = self.pooler(sequence_output_ln)
        pooler_output = self.transformation(pooler_output)
        projection_state = self.transformation(outputs.last_hidden_state)

        return {
            'pooler_output':pooler_output,
            'last_hidden_state':outputs.last_hidden_state,
            'hidden_states':outputs.hidden_states,
            'attentions':outputs.attentions,
            'projection_state':projection_state,
            'sequence_out': sequence_output
        }


class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
    base_model_prefix = 'roberta'
    config_class= RobertaSeriesConfig
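
A rough usage sketch (not part of the commit; the checkpoint path and prompt are placeholders) of how the AltDiffusion config's cond_stage_config ends up driving this class:

# Illustrative only: configs/alt-diffusion-inference.yaml points cond_stage_config
# at modules.xlmr.BertSeriesModelWithTransformation; encode() returns the projected
# hidden states for every token position, roughly [batch, 77, 768].
from modules.xlmr import BertSeriesModelWithTransformation

model = BertSeriesModelWithTransformation.from_pretrained("some/xlmr-checkpoint")  # placeholder path
cond = model.encode(["a watercolor painting of a fox"])
print(cond.shape)  # expected: torch.Size([1, 77, 768])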
@@ -5,7 +5,7 @@ fairscale==0.4.4
fonts
font-roboto
gfpgan
gradio==3.9
gradio==3.15.0
invisible-watermark
numpy
omegaconf

@@ -3,7 +3,7 @@ transformers==4.19.2
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.9
gradio==3.15.0
numpy==1.23.3
Pillow==9.2.0
realesrgan==0.3.0
@@ -140,7 +140,7 @@ class Script(scripts.Script):
                try:
                    args = cmdargs(line)
                except Exception:
                    print(f"Error parsing line [line] as commandline:", file=sys.stderr)
                    print(f"Error parsing line {line} as commandline:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    args = {"prompt": line}
            else:

@@ -35,8 +35,9 @@ class Script(scripts.Script):
        seed = p.seed

        init_img = p.init_images[0]
        init_img = images.flatten(init_img, opts.img2img_background_color)

        if (upscaler.name != "None"):
        if upscaler.name != "None":
            img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
        else:
            img = init_img
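
For context, a hedged sketch of what the images.flatten call above presumably does (the real helper lives in modules/images.py and is not shown in this diff): it composites a transparent init image over the configured background colour before upscaling.

# Rough, assumed equivalent of the flatten step, written out for illustration;
# the actual implementation in modules/images.py may differ in detail.
from PIL import Image

def flatten_like(img: Image.Image, bgcolor: str) -> Image.Image:
    if img.mode == "RGBA":
        background = Image.new("RGBA", img.size, bgcolor)
        background.alpha_composite(img)
        img = background
    return img.convert("RGB")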
39
style.css

@@ -245,11 +245,6 @@ input[type="range"]{
    margin: 0.5em 0 -0.3em 0;
}

#txt2img_sampling label{
    padding-left: 0.6em;
    padding-right: 0.6em;
}

#mask_bug_info {
    text-align: center;
    display: block;

@@ -501,13 +496,6 @@ input[type="range"]{
    padding: 0;
}

#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
    max-width: 2.5em;
    min-width: 2.5em;
    height: 2.4em;
}


canvas[key="mask"] {
    z-index: 12 !important;
    filter: invert();
@@ -568,6 +556,33 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
    font-size: 95%;
}

#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
    min-width: auto;
    padding-left: 0.5em;
    padding-right: 0.5em;
}

.gr-form{
    background-color: white;
}

.dark .gr-form{
    background-color: rgb(31 41 55 / var(--tw-bg-opacity));
}

.gr-button-tool{
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.4em;
    margin: 0.55em 0;
}

#quicksettings .gr-button-tool{
    margin: 0;
}


/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
68
v2-inference-v.yaml
Normal file

@@ -0,0 +1,68 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
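
For reference, a minimal sketch (assumptions: the stable-diffusion ldm package is importable and a matching checkpoint is on disk; neither is part of this diff) of how a config like this is typically turned into a model:

# Illustrative only: OmegaConf reads the YAML and instantiate_from_config builds
# the LatentDiffusion object named by model.target; the checkpoint path is a placeholder.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import torch

config = OmegaConf.load("v2-inference-v.yaml")
model = instantiate_from_config(config.model)
state = torch.load("v2-1_768-ema-pruned.ckpt", map_location="cpu")  # placeholder checkpoint
model.load_state_dict(state["state_dict"], strict=False)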