stable-diffusion-webui/modules/sd_hijack.py

251 lines
9.2 KiB
Python
Raw Normal View History

import os
import sys
import traceback
import torch
import numpy as np
from torch import einsum
from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass that processes the (batch * heads) axis in slices.

    Installed as a replacement for ``CrossAttention.forward``, so ``self`` is
    the attention module; it must provide ``heads``, ``scale``, ``to_q``,
    ``to_k``, ``to_v`` and ``to_out``. Only two rows of the (batch * heads)
    axis have their attention matrix materialised at any time, which lowers
    peak memory use.

    Args:
        x: query input, laid out as ``b n (h d)`` after ``to_q``.
        context: key/value input; falls back to ``x`` when not given.
        mask: accepted for signature compatibility; this implementation
            does not use it.

    Returns:
        The result of ``self.to_out`` applied to the attention output.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x  # drop references early so the inputs can be freed

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    # Allocate the accumulator in q's dtype. Without an explicit dtype
    # torch.zeros uses the default (normally float32), so a half-precision
    # model would hand a float32 tensor to to_out.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1  # the raw scores are no longer needed once softmax is taken

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
class StableDiffusionModelHijack:
    """Keeps textual-inversion embeddings and installs the prompt hooks on a model."""

    # NOTE: the containers below are class attributes and are therefore shared
    # by every instance of this class.
    # first token id of an embedding name -> list of (token ids, name)
    ids_lookup = {}
    # embedding name -> detached embedding tensor
    word_embeddings = {}
    # embedding name -> 4-hex-digit checksum string
    word_embeddings_checksums = {}
    # per-prompt lists of (offset, name); set by the text encoder wrapper and
    # consumed (reset to None) by the token-embedding wrapper
    fixes = None
    # messages collected while prompts are processed
    comments = []
    # modification time of the embeddings directory at the last (re)load
    dir_mtime = None

    def load_textual_inversion_embeddings(self, dirname, model):
        """(Re)load every embedding file found in ``dirname``.

        Does nothing when the directory's modification time has not advanced
        since the previous load. Files that fail to load are reported on
        stderr and skipped.

        Args:
            dirname: directory containing the embedding files.
            model: object exposing ``cond_stage_model.tokenizer``.
        """
        mt = os.path.getmtime(dirname)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        # Also drop checksums, otherwise entries for removed files linger.
        self.word_embeddings_checksums.clear()

        tokenizer = model.cond_stage_model.tokenizer

        def const_hash(a):
            # NOTE: int(v) truncates toward zero, so only the integer part of
            # each value contributes to the hash.
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            # SECURITY: torch.load unpickles the file; only place embedding
            # files from trusted sources in this directory.
            data = torch.load(path)
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
            # Explicit raise instead of assert so the check survives ``python -O``.
            if len(param_dict) != 1:
                raise ValueError('embedding file has multiple terms in it')
            emb = next(iter(param_dict.items()))[1]

            # Do everything that can fail before touching the registries so a
            # bad file never leaves a half-registered embedding behind.
            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
            first_id = ids[0]
            checksum = f'{const_hash(emb.reshape(-1))&0xffff:04x}'

            self.word_embeddings[name] = emb.detach()
            self.word_embeddings_checksums[name] = checksum
            self.ids_lookup.setdefault(first_id, []).append((ids, name))

        for fn in os.listdir(dirname):
            try:
                process_file(os.path.join(dirname, fn), fn)
            except Exception:
                # Best effort: report the broken file and carry on with the rest.
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")

    def hijack(self, m):
        """Wrap the model's token embedding and text encoder with the custom versions.

        When ``cmd_opts.opt_split_attention`` is set, ``CrossAttention.forward``
        is also replaced with ``split_cross_attention_forward``.
        """
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        if cmd_opts.opt_split_attention:
            ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    """Text-encoder wrapper adding bracket emphasis and custom embedding words.

    ``wrapped`` supplies ``tokenizer``, ``max_length`` and ``transformer``;
    ``hijack`` supplies ``ids_lookup``, ``word_embeddings``,
    ``word_embeddings_checksums`` and receives ``fixes`` / ``comments``.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
        # token id -> emphasis factor, for vocabulary entries containing brackets
        self.token_mults = {}

        for text, ident in self.tokenizer.get_vocab().items():
            if not any(bracket in text for bracket in '()[]'):
                continue

            mult = 1.0
            for c in text:
                if c in '(]':
                    # '(' and ']' both raise the weight
                    mult *= 1.1
                elif c in '[)':
                    # '[' and ')' both lower the weight
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

    def forward(self, text):
        """Encode ``text`` and return the weighted last hidden state.

        Side effects on ``self.hijack``: ``fixes`` receives one list of
        ``(offset, word)`` per prompt, where ``offset`` indexes the prompt
        tokens *before* the start token is prepended; ``comments`` receives
        truncation warnings and the list of custom terms used.
        """
        self.hijack.fixes = []
        self.hijack.comments = []

        tokenizer = self.wrapped.tokenizer
        id_start = tokenizer.bos_token_id
        id_end = tokenizer.eos_token_id
        maxlen = self.wrapped.max_length - 2
        # NOTE(review): maxlen already has 2 subtracted and 2 is subtracted
        # again here, so each row ends up max_length - 2 ids long. Looks
        # unintended, but changing it would alter the output length - confirm.
        body_len = maxlen - 2

        used_custom_terms = []
        cache = {}  # tuple(prompt tokens) -> (remade tokens, fixes, multipliers)
        remade_batch_tokens = []
        batch_multipliers = []

        batch_tokens = tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]

        for tokens in batch_tokens:
            key = tuple(tokens)
            cached = cache.get(key)

            if cached is not None:
                remade_tokens, fixes, multipliers = cached
            else:
                fixes, remade_tokens, multipliers = [], [], []
                mult = 1.0
                pos = 0

                while pos < len(tokens):
                    token = tokens[pos]

                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                    if mult_change is not None:
                        # bracket token: adjust the running weight and drop the token
                        mult *= mult_change
                        pos += 1
                        continue

                    # look for a custom embedding name starting at this position
                    matched = None
                    for ids, word in self.hijack.ids_lookup.get(token, ()):
                        if tokens[pos:pos + len(ids)] == ids:
                            matched = (ids, word)
                            break

                    if matched is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                        pos += 1
                        continue

                    ids, word = matched
                    emb_len = int(self.hijack.word_embeddings[word].shape[0])
                    fixes.append((len(remade_tokens), word))
                    # reserve one placeholder id per embedding vector
                    remade_tokens.extend([0] * emb_len)
                    multipliers.extend([mult] * emb_len)
                    used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                    pos += len(ids)

                if len(remade_tokens) > body_len:
                    vocab = {v: k for k, v in tokenizer.get_vocab().items()}
                    ovf = remade_tokens[body_len:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                # pad with the end token, cut to size, then add start/end markers
                padded = remade_tokens + [id_end] * (body_len - len(remade_tokens))
                remade_tokens = [id_start, *padded[:body_len], id_end]
                cache[key] = (remade_tokens, fixes, multipliers)

            # the cache holds unpadded multipliers, so pad on every pass
            padded_mults = multipliers + [1.0] * (body_len - len(multipliers))
            multipliers = [1.0, *padded_mults[:body_len], 1.0]

            remade_batch_tokens.append(remade_tokens)
            self.hijack.fixes.append(fixes)
            batch_multipliers.append(multipliers)

        if used_custom_terms:
            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
        original_mean = z.mean()
        z *= batch_multipliers.unsqueeze(-1).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z
class EmbeddingsWithFixes(torch.nn.Module):
    """Token-embedding wrapper that splices custom embedding vectors into the output.

    ``wrapped`` is the original embedding callable. ``embeddings`` is the
    shared state object providing ``fixes`` (one list of ``(offset, word)``
    per batch row, or ``None``) and ``word_embeddings`` (word -> tensor).
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        """Embed ``input_ids`` and overwrite placeholder rows with custom vectors.

        Pending fixes are consumed: ``self.embeddings.fixes`` is reset to
        ``None`` so they apply to this call only.
        """
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is not None:
            for fixes, tensor in zip(batch_fixes, inputs_embeds):
                # The last row is the end token and must stay untouched.
                end_slot = tensor.shape[0] - 1
                for offset, word in fixes:
                    emb = self.embeddings.word_embeddings[word]
                    # Offsets are recorded relative to the prompt tokens before
                    # the start token is prepended, so shift by one.
                    start = offset + 1
                    emb_len = min(end_slot - start, emb.shape[0])
                    if emb_len <= 0:
                        # The placeholder was truncated away entirely.
                        continue
                    tensor[start:start + emb_len] = emb[0:emb_len]

        return inputs_embeds
# Module-level StableDiffusionModelHijack instance, created when this module is imported.
model_hijack = StableDiffusionModelHijack()