implementation for attention using [] and ()
This commit is contained in:
parent
a51bedfb5a
commit
9597b265ec
|
@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt.
|
||||||
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
|
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
|
||||||
were commandline. Settings are saved to config.js file. Settings that remain as commandline
|
were commandline. Settings are saved to config.js file. Settings that remain as commandline
|
||||||
options are ones that are required at startup.
|
options are ones that are required at startup.
|
||||||
|
|
||||||
|
### Attention
|
||||||
|
Using `()` in prompt decreases model's attention to enclosed words, and `[]` increases it. You can combine
|
||||||
|
multiple modifiers:
|
||||||
|
|
||||||
|
![](images/attention-3.jpg)
|
||||||
|
|
BIN
images/attention-3.jpg
Normal file
BIN
images/attention-3.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 944 KiB |
73
webui.py
73
webui.py
|
@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
class TextInversionEmbeddings:
|
class StableDiffuionModelHijack:
|
||||||
ids_lookup = {}
|
ids_lookup = {}
|
||||||
word_embeddings = {}
|
word_embeddings = {}
|
||||||
word_embeddings_checksums = {}
|
word_embeddings_checksums = {}
|
||||||
fixes = []
|
fixes = None
|
||||||
used_custom_terms = []
|
used_custom_terms = []
|
||||||
dir_mtime = None
|
dir_mtime = None
|
||||||
|
|
||||||
def load(self, dir, model):
|
def load_textual_inversion_embeddings(self, dir, model):
|
||||||
mt = os.path.getmtime(dir)
|
mt = os.path.getmtime(dir)
|
||||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||||
return
|
return
|
||||||
|
@ -469,6 +469,7 @@ class TextInversionEmbeddings:
|
||||||
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
|
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
|
||||||
|
|
||||||
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
||||||
|
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
if first_id not in self.ids_lookup:
|
if first_id not in self.ids_lookup:
|
||||||
self.ids_lookup[first_id] = []
|
self.ids_lookup[first_id] = []
|
||||||
|
@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
self.embeddings = embeddings
|
self.embeddings = embeddings
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
self.max_length = wrapped.max_length
|
self.max_length = wrapped.max_length
|
||||||
|
self.token_mults = {}
|
||||||
|
|
||||||
|
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||||
|
for text, ident in tokens_with_parens:
|
||||||
|
mult = 1.0
|
||||||
|
for c in text:
|
||||||
|
if c == '[':
|
||||||
|
mult /= 1.1
|
||||||
|
if c == ']':
|
||||||
|
mult *= 1.1
|
||||||
|
if c == '(':
|
||||||
|
mult *= 1.1
|
||||||
|
if c == ')':
|
||||||
|
mult /= 1.1
|
||||||
|
|
||||||
|
if mult != 1.0:
|
||||||
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
self.embeddings.fixes = []
|
self.embeddings.fixes = []
|
||||||
|
@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
batch_multipliers = []
|
||||||
for tokens in batch_tokens:
|
for tokens in batch_tokens:
|
||||||
tuple_tokens = tuple(tokens)
|
tuple_tokens = tuple(tokens)
|
||||||
|
|
||||||
if tuple_tokens in cache:
|
if tuple_tokens in cache:
|
||||||
remade_tokens, fixes = cache[tuple_tokens]
|
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||||
else:
|
else:
|
||||||
fixes = []
|
fixes = []
|
||||||
remade_tokens = []
|
remade_tokens = []
|
||||||
|
multipliers = []
|
||||||
|
mult = 1.0
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
|
@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
|
|
||||||
possible_matches = self.embeddings.ids_lookup.get(token, None)
|
possible_matches = self.embeddings.ids_lookup.get(token, None)
|
||||||
|
|
||||||
if possible_matches is None:
|
mult_change = self.token_mults.get(token)
|
||||||
|
if mult_change is not None:
|
||||||
|
mult *= mult_change
|
||||||
|
elif possible_matches is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(mult)
|
||||||
else:
|
else:
|
||||||
found = False
|
found = False
|
||||||
for ids, word in possible_matches:
|
for ids, word in possible_matches:
|
||||||
if tokens[i:i+len(ids)] == ids:
|
if tokens[i:i+len(ids)] == ids:
|
||||||
fixes.append((len(remade_tokens), word))
|
fixes.append((len(remade_tokens), word))
|
||||||
remade_tokens.append(777)
|
remade_tokens.append(777)
|
||||||
|
multipliers.append(mult)
|
||||||
i += len(ids) - 1
|
i += len(ids) - 1
|
||||||
found = True
|
found = True
|
||||||
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
|
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
|
||||||
|
@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
|
|
||||||
if not found:
|
if not found:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(mult)
|
||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
cache[tuple_tokens] = (remade_tokens, fixes)
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
|
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||||
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
remade_batch_tokens.append(remade_tokens)
|
||||||
self.embeddings.fixes.append(fixes)
|
self.embeddings.fixes.append(fixes)
|
||||||
|
batch_multipliers.append(multipliers)
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
|
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
outputs = self.wrapped.transformer(input_ids=tokens)
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
|
|
||||||
|
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||||
|
batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
|
||||||
|
original_mean = z.mean()
|
||||||
|
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||||
|
new_mean = z.mean()
|
||||||
|
z *= original_mean / new_mean
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
@ -562,24 +601,19 @@ class EmbeddingsWithFixes(nn.Module):
|
||||||
|
|
||||||
def forward(self, input_ids):
|
def forward(self, input_ids):
|
||||||
batch_fixes = self.embeddings.fixes
|
batch_fixes = self.embeddings.fixes
|
||||||
self.embeddings.fixes = []
|
self.embeddings.fixes = None
|
||||||
|
|
||||||
inputs_embeds = self.wrapped(input_ids)
|
inputs_embeds = self.wrapped(input_ids)
|
||||||
|
|
||||||
|
if batch_fixes is not None:
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
for offset, word in fixes:
|
for offset, word in fixes:
|
||||||
tensor[offset] = self.embeddings.word_embeddings[word]
|
tensor[offset] = self.embeddings.word_embeddings[word]
|
||||||
|
|
||||||
|
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
|
|
||||||
def get_learned_conditioning_with_embeddings(model, prompts):
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
|
||||||
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
|
||||||
|
|
||||||
return model.get_learned_conditioning(prompts)
|
|
||||||
|
|
||||||
|
|
||||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
|
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
|
@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
||||||
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
|
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
|
||||||
|
|
||||||
output_images = []
|
output_images = []
|
||||||
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
||||||
|
@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
||||||
uc = model.get_learned_conditioning(len(prompts) * [""])
|
uc = model.get_learned_conditioning(len(prompts) * [""])
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
if len(text_inversion_embeddings.used_custom_terms) > 0:
|
if len(model_hijack.used_custom_terms) > 0:
|
||||||
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
|
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
|
||||||
|
|
||||||
# we manually generate all input noises because each one should have a specific seed
|
# we manually generate all input noises because each one should have a specific seed
|
||||||
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
||||||
|
@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
|
||||||
|
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = (model if cmd_opts.no_half else model.half()).to(device)
|
model = (model if cmd_opts.no_half else model.half()).to(device)
|
||||||
text_inversion_embeddings = TextInversionEmbeddings()
|
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
model_hijack = StableDiffuionModelHijack()
|
||||||
text_inversion_embeddings.hijack(model)
|
model_hijack.hijack(model)
|
||||||
|
|
||||||
demo = gr.TabbedInterface(
|
demo = gr.TabbedInterface(
|
||||||
interface_list=[x[0] for x in interfaces],
|
interface_list=[x[0] for x in interfaces],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user