fixes related to merge

AUTOMATIC 2022-10-11 14:53:02 +03:00
parent 5de806184f
commit 530103b586
9 changed files with 82 additions and 164 deletions

modules/hypernetwork.py (deleted)

@@ -1,103 +0,0 @@
-import glob
-import os
-import sys
-import traceback
-
-import torch
-
-from ldm.util import default
-from modules import devices, shared
-import torch
-from torch import einsum
-from einops import rearrange, repeat
-
-
-class HypernetworkModule(torch.nn.Module):
-    def __init__(self, dim, state_dict):
-        super().__init__()
-
-        self.linear1 = torch.nn.Linear(dim, dim * 2)
-        self.linear2 = torch.nn.Linear(dim * 2, dim)
-
-        self.load_state_dict(state_dict, strict=True)
-        self.to(devices.device)
-
-    def forward(self, x):
-        return x + (self.linear2(self.linear1(x)))
-
-
-class Hypernetwork:
-    filename = None
-    name = None
-
-    def __init__(self, filename):
-        self.filename = filename
-        self.name = os.path.splitext(os.path.basename(filename))[0]
-        self.layers = {}
-
-        state_dict = torch.load(filename, map_location='cpu')
-        for size, sd in state_dict.items():
-            self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
-
-
-def list_hypernetworks(path):
-    res = {}
-    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
-        name = os.path.splitext(os.path.basename(filename))[0]
-        res[name] = filename
-    return res
-
-
-def load_hypernetwork(filename):
-    path = shared.hypernetworks.get(filename, None)
-    if path is not None:
-        print(f"Loading hypernetwork {filename}")
-        try:
-            shared.loaded_hypernetwork = Hypernetwork(path)
-        except Exception:
-            print(f"Error loading hypernetwork {path}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-    else:
-        if shared.loaded_hypernetwork is not None:
-            print(f"Unloading hypernetwork")
-
-        shared.loaded_hypernetwork = None
-
-
-def apply_hypernetwork(hypernetwork, context):
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
-    if hypernetwork_layers is None:
-        return context, context
-
-    context_k = hypernetwork_layers[0](context)
-    context_v = hypernetwork_layers[1](context)
-    return context_k, context_v
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
-    h = self.heads
-
-    q = self.to_q(x)
-    context = default(context, x)
-
-    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
-    k = self.to_k(context_k)
-    v = self.to_v(context_v)
-
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-    if mask is not None:
-        mask = rearrange(mask, 'b ... -> b (...)')
-        max_neg_value = -torch.finfo(sim.dtype).max
-        mask = repeat(mask, 'b j -> (b h) () j', h=h)
-        sim.masked_fill_(~mask, max_neg_value)
-
-    # attention, what we cannot get enough of
-    attn = sim.softmax(dim=-1)
-
-    out = einsum('b i j, b j d -> b i d', attn, v)
-    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-    return self.to_out(out)
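
The module deleted above survives as the package touched next. For orientation, the call pattern the rest of this commit migrates to looks roughly like this, a sketch against the webui modules as laid out here, with "some_network" standing in for a real network name:

# Sketch of the replacement API, not webui code verbatim: discovery returns
# {name: path}, and loading by name sets shared.loaded_hypernetwork.
from modules import shared
from modules.hypernetwork import hypernetwork

shared.hypernetworks = hypernetwork.list_hypernetworks(shared.cmd_opts.hypernetwork_dir)
hypernetwork.load_hypernetwork("some_network")  # hypothetical name; an unknown name unloads instead
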

modules/hypernetwork/hypernetwork.py

@@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module):
         if state_dict is not None:
             self.load_state_dict(state_dict, strict=True)
         else:
-            self.linear1.weight.data.fill_(0.0001)
-            self.linear1.bias.data.fill_(0.0001)
-            self.linear2.weight.data.fill_(0.0001)
-            self.linear2.bias.data.fill_(0.0001)
+            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+            self.linear1.bias.data.zero_()
+            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+            self.linear2.bias.data.zero_()
+
         self.to(devices.device)
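
The initialization change above matters more than it looks: filling every weight with the same constant makes all hidden units of linear1 compute identical outputs, so their gradients stay identical and the doubled width is wasted; a small Gaussian breaks the symmetry, and zeroed biases keep a fresh module close to an identity residual. A self-contained sketch of the new scheme:

import torch

dim = 8
linear1 = torch.nn.Linear(dim, dim * 2)
linear2 = torch.nn.Linear(dim * 2, dim)

# the initialization introduced in this hunk
linear1.weight.data.normal_(mean=0.0, std=0.01)
linear1.bias.data.zero_()
linear2.weight.data.normal_(mean=0.0, std=0.01)
linear2.bias.data.zero_()

x = torch.randn(2, dim)
out = x + linear2(linear1(x))             # HypernetworkModule.forward
print(torch.allclose(out, x, atol=1e-2))  # True: an untrained module is near-identity
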
@@ -92,41 +93,54 @@ class Hypernetwork:
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
 
 
-def load_hypernetworks(path):
+def list_hypernetworks(path):
     res = {}
-    for filename in glob.iglob(path + '**/*.pt', recursive=True):
-        try:
-            hn = Hypernetwork()
-            hn.load(filename)
-            res[hn.name] = hn
-        except Exception:
-            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+        name = os.path.splitext(os.path.basename(filename))[0]
+        res[name] = filename
     return res
 
 
+def load_hypernetwork(filename):
+    path = shared.hypernetworks.get(filename, None)
+    if path is not None:
+        print(f"Loading hypernetwork {filename}")
+        try:
+            shared.loaded_hypernetwork = Hypernetwork()
+            shared.loaded_hypernetwork.load(path)
+        except Exception:
+            print(f"Error loading hypernetwork {path}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+    else:
+        if shared.loaded_hypernetwork is not None:
+            print(f"Unloading hypernetwork")
+
+        shared.loaded_hypernetwork = None
+
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is None:
+        return context, context
+
+    if layer is not None:
+        layer.hyper_k = hypernetwork_layers[0]
+        layer.hyper_v = hypernetwork_layers[1]
+
+    context_k = hypernetwork_layers[0](context)
+    context_v = hypernetwork_layers[1](context)
+    return context_k, context_v
+
+
 def attention_CrossAttention_forward(self, x, context=None, mask=None):
     h = self.heads
 
     q = self.to_q(x)
     context = default(context, x)
 
-    hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
-
-    if hypernetwork_layers is not None:
-        hypernetwork_k, hypernetwork_v = hypernetwork_layers
-
-        self.hypernetwork_k = hypernetwork_k
-        self.hypernetwork_v = hypernetwork_v
-
-        context_k = hypernetwork_k(context)
-        context_v = hypernetwork_v(context)
-    else:
-        context_k = context
-        context_v = context
+    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
 
     k = self.to_k(context_k)
     v = self.to_v(context_v)
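
A subtle fix hides in list_hypernetworks above: the old code concatenated path + '**/*.pt' without a separator, and Python's glob treats ** as recursive only when it stands alone as a path component, so a pattern like models/hypernetworks**/*.pt never descended into subdirectories (and could even match sibling directories such as models/hypernetworks-old). A standalone sketch with hypothetical paths:

import glob
import os

path = 'models/hypernetworks'  # hypothetical directory containing sub/style.pt

old_pattern = path + '**/*.pt'               # 'models/hypernetworks**/*.pt'
new_pattern = os.path.join(path, '**/*.pt')  # 'models/hypernetworks/**/*.pt'

# '**' fused onto 'hypernetworks' degrades to ordinary wildcards and matches
# only one directory level; the joined pattern recurses properly.
print(glob.glob(old_pattern, recursive=True))
print(glob.glob(new_pattern, recursive=True))
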
@@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
 
 def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
     assert hypernetwork_name, 'embedding not selected'
 
-    shared.hypernetwork = shared.hypernetworks[hypernetwork_name]
+    path = shared.hypernetworks.get(hypernetwork_name, None)
+    shared.loaded_hypernetwork = Hypernetwork()
+    shared.loaded_hypernetwork.load(path)
 
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
@@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
 
-    hypernetwork = shared.hypernetworks[hypernetwork_name]
+    hypernetwork = shared.loaded_hypernetwork
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True

modules/hypernetwork/ui.py

@@ -6,24 +6,24 @@ import gradio as gr
 import modules.textual_inversion.textual_inversion
 import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
+from modules.hypernetwork import hypernetwork
 
 
 def create_hypernetwork(name):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     assert not os.path.exists(fn), f"file {fn} already exists"
 
-    hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
-    hypernetwork.save(fn)
+    hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+    hypernet.save(fn)
 
     shared.reload_hypernetworks()
-    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
 
     return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
 
 
 def train_hypernetwork(*args):
-    initial_hypernetwork = shared.hypernetwork
+    initial_hypernetwork = shared.loaded_hypernetwork
 
     try:
         sd_hijack.undo_optimizations()
@@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
-        shared.hypernetwork = initial_hypernetwork
+        shared.loaded_hypernetwork = initial_hypernetwork
         sd_hijack.apply_optimizations()
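
The hypernetwork-to-hypernet rename in create_hypernetwork is not cosmetic: with the module now imported at the top of the file, a local variable of the same name would shadow it for the rest of the function. A minimal illustration of the hazard, using json as a stand-in for the imported module:

import json  # stand-in for the imported hypernetwork module

def broken():
    json = {'name': 'x'}         # the local name shadows the module
    return json.dumps(json)      # AttributeError: 'dict' object has no attribute 'dumps'

def fixed():
    payload = {'name': 'x'}      # distinct name, like 'hypernet' in this commit
    return json.dumps(payload)   # the module stays reachable
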

modules/sd_hijack_optimizations.py

@@ -8,7 +8,8 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared, hypernetwork
+from modules import shared
+from modules.hypernetwork import hypernetwork
 
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:

modules/shared.py

@@ -13,7 +13,8 @@ import modules.memmon
 import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers
+from modules.hypernetwork import hypernetwork
 from modules.paths import models_path, script_path, sd_path
 
 sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
 parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
 parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
 parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
 parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
 xformers_available = False
 config_filename = cmd_opts.ui_settings_file
 
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
 loaded_hypernetwork = None
 
 
+def reload_hypernetworks():
+    global hypernetworks
+
+    hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+    hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
+
 class State:
     skipped = False
     interrupted = False
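
reload_hypernetworks ties the new --hypernetwork-dir flag to the sd_hypernetwork setting: rescan the directory, then re-resolve whatever name the user selected. A standalone sketch of that flow with hypothetical paths and names:

import glob
import os

def list_hypernetworks(path):
    # mirrors the discovery above: map network name -> .pt path, recursively
    res = {}
    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
        res[os.path.splitext(os.path.basename(filename))[0]] = filename
    return res

hypernetworks = list_hypernetworks('models/hypernetworks')  # hypothetical dir
selected = 'some_network'                 # stands in for opts.sd_hypernetwork
path = hypernetworks.get(selected, None)  # None means unload, as in load_hypernetwork
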

modules/textual_inversion/textual_inversion.py

@@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -238,9 +238,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
             last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
 
+            preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
-                prompt=text,
+                prompt=preview_text,
                 steps=20,
                 height=training_height,
                 width=training_width,
@@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             shared.state.current_image = image
             image.save(last_saved_image)
 
-            last_saved_image += f", prompt: {text}"
+            last_saved_image += f", prompt: {preview_text}"
 
             shared.state.job_no = embedding.step

modules/ui.py

@@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call):
                     gr.HTML(value="")
 
                 with gr.Column():
-                    create_embedding = gr.Button(value="Create", variant='primary')
+                    create_embedding = gr.Button(value="Create embedding", variant='primary')
 
             with gr.Group():
                 gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
@@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call):
                     gr.HTML(value="")
 
                 with gr.Column():
-                    create_hypernetwork = gr.Button(value="Create", variant='primary')
+                    create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
 
             with gr.Group():
                 gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
@@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call):
                 create_image_every,
                 save_embedding_every,
                 template_file,
+                preview_image_prompt,
             ],
             outputs=[
                 ti_output,

scripts/xy_grid.py

@@ -10,7 +10,8 @@ import numpy as np
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images, hypernetwork
+from modules import images
+from modules.hypernetwork import hypernetwork
 from modules.processing import process_images, Processed, get_correct_sampler
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared

webui.py

@@ -29,6 +29,7 @@ from modules import devices
 from modules import modelloader
 from modules.paths import script_path
 from modules.shared import cmd_opts
+import modules.hypernetwork.hypernetwork
 
 modelloader.cleanup_models()
 modules.sd_models.setup_model()
@@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
     return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
 
 
-def set_hypernetwork():
-    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
-
-
-shared.opts.onchange("sd_hypernetwork", set_hypernetwork)
-set_hypernetwork()
-
-
+shared.reload_hypernetworks()
+
 modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
 
 shared.sd_model = modules.sd_models.load_model()
 shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
 
-loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
-shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
 
 
 def webui():
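
The last hunk leans on the webui's Options.onchange registry: register a callback for a settings key, and it fires when that key changes. A toy version of the register-then-fire behavior (not the actual Options implementation in modules/shared.py):

class Options:
    def __init__(self):
        self._values = {}
        self._handlers = {}

    def onchange(self, key, func):
        self._handlers[key] = func

    def set(self, key, value):
        self._values[key] = value
        if key in self._handlers:
            self._handlers[key]()  # fire the callback registered for this key

opts = Options()
opts.onchange('sd_hypernetwork', lambda: print('reloading hypernetwork'))
opts.set('sd_hypernetwork', 'some_network')  # prints: reloading hypernetwork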