implementation for attention using [] and ()

AUTOMATIC 2022-08-27 11:17:55 +03:00
parent a51bedfb5a
commit 9597b265ec
3 changed files with 62 additions and 23 deletions


@@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt.
A settings tab, allowing you to use the UI to edit more than half of the parameters that previously
were command-line options. Settings are saved to the config.js file. The settings that remain
command-line options are those required at startup.
### Attention
Using `()` in the prompt increases the model's attention to enclosed words, and `[]` decreases it. You can combine
multiple modifiers:
![](images/attention-3.jpg)
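For example, each level of nesting compounds the 1.1× factor (the prompt below is illustrative, not taken from this commit):

```
a ((fantasy)) landscape with a [cloudy] sky
```

Here `fantasy` gets 1.1 × 1.1 = 1.21× attention and `cloudy` gets 1/1.1 ≈ 0.91×.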

BIN images/attention-3.jpg — new file, 944 KiB (binary file not shown)


@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
print(traceback.format_exc(), file=sys.stderr)
class TextInversionEmbeddings:
class StableDiffuionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = []
fixes = None
used_custom_terms = []
dir_mtime = None
def load(self, dir, model):
def load_textual_inversion_embeddings(self, dir, model):
mt = os.path.getmtime(dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
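The guard at the top of `load_textual_inversion_embeddings` skips the reload when the embeddings directory is unchanged. A minimal standalone sketch of that mtime check (class and method names here are illustrative):

```python
import os

class EmbeddingLoaderSketch:
    """Illustration of the directory-mtime reload guard, not the commit's class."""
    def __init__(self):
        self.dir_mtime = None  # mtime observed at the last successful load

    def load_if_changed(self, dir):
        mt = os.path.getmtime(dir)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return False  # directory untouched since the last load: skip re-reading
        self.dir_mtime = mt
        # ... re-read the embedding files here ...
        return True
```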
@@ -469,6 +469,7 @@ class TextInversionEmbeddings:
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
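`ids_lookup` indexes each custom term by the first token id of its name; the prompt parser later probes it with `tokens[i:i+len(ids)] == ids` (see the matching hunk below). A self-contained sketch with made-up ids:

```python
# Hypothetical data: suppose "usada pekora" tokenizes to ids [320, 415].
ids_lookup = {
    320: [([320, 415], "usada pekora")],  # first id -> [(full id sequence, word)]
}

def match_custom_term(tokens, i):
    """Return (ids, word) if a custom term starts at position i (sketch)."""
    for ids, word in ids_lookup.get(tokens[i], []):
        if tokens[i:i + len(ids)] == ids:
            return ids, word
    return None

assert match_custom_term([7, 320, 415, 9], 1) == ([320, 415], "usada pekora")
```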
@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.embeddings = embeddings
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
def forward(self, text):
self.embeddings.fixes = []
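The `token_mults` loop above assigns every vocabulary token containing brackets a net multiplier. The same rule as a standalone function, with a couple of sanity checks:

```python
def bracket_multiplier(token_text):
    """Net attention multiplier contributed by one vocabulary token (sketch)."""
    mult = 1.0
    for c in token_text:
        if c == '(':
            mult *= 1.1   # opening paren: raise attention for what follows
        elif c == ')':
            mult /= 1.1   # closing paren: undo one level
        elif c == '[':
            mult /= 1.1   # opening bracket: lower attention
        elif c == ']':
            mult *= 1.1   # closing bracket: undo one level
    return mult

assert abs(bracket_multiplier('((') - 1.21) < 1e-9
assert abs(bracket_multiplier('[') - 1 / 1.1) < 1e-9
```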
@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes = cache[tuple_tokens]
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
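The per-prompt cache above keys on `tuple(tokens)` because Python lists are unhashable; identical prompts within a batch then reuse the rewritten tokens, fixes, and multipliers. The pattern in miniature:

```python
cache = {}

def remade_for(tokens):
    """Sketch of the batch cache; the real value is (remade_tokens, fixes, multipliers)."""
    key = tuple(tokens)        # lists cannot be dict keys, tuples can
    if key not in cache:
        cache[key] = list(tokens)  # stand-in for the expensive rewrite
    return cache[key]
```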
@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
possible_matches = self.embeddings.ids_lookup.get(token, None)
if possible_matches is None:
mult_change = self.token_mults.get(token)
if mult_change is not None:
mult *= mult_change
elif possible_matches is None:
remade_tokens.append(token)
multipliers.append(mult)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
fixes.append((len(remade_tokens), word))
remade_tokens.append(777)
multipliers.append(mult)
i += len(ids) - 1
found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if not found:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes)
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes)
batch_multipliers.append(multipliers)
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
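The comment in the hunk above flags the mean-restoring rescale as a heuristic. A toy reproduction of just that step, assuming torch (shapes are illustrative):

```python
import torch

torch.manual_seed(0)
z = torch.randn(1, 5, 4)  # stand-in for CLIP's last_hidden_state (batch, tokens, dim)
batch_multipliers = torch.tensor([[1.0, 1.21, 1.0, 0.909, 1.0]])  # per-token weights

original_mean = z.mean()
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
z = z * (original_mean / z.mean())  # rescale so the global mean is unchanged

assert torch.isclose(z.mean(), original_mean)
```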
@@ -562,24 +601,19 @@ class EmbeddingsWithFixes(nn.Module):
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
self.embeddings.fixes = []
self.embeddings.fixes = None
inputs_embeds = self.wrapped(input_ids)
if batch_fixes is not None:
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes:
tensor[offset] = self.embeddings.word_embeddings[word]
return inputs_embeds
def get_learned_conditioning_with_embeddings(model, prompts):
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
return model.get_learned_conditioning(prompts)
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope():
@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts)
if len(text_inversion_embeddings.used_custom_terms) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
if len(model_hijack.used_custom_terms) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
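One way to realize the comment above: seed each sample's noise separately, so an image is reproducible regardless of its position in the batch. A sketch under that assumption, not the actual `create_random_tensors`:

```python
import torch

def create_random_tensors_sketch(shape, seeds):
    """One independently seeded noise tensor per sample, stacked into a batch."""
    xs = []
    for seed in seeds:
        torch.manual_seed(seed)       # each sample's noise depends only on its seed
        xs.append(torch.randn(shape))
    return torch.stack(xs)

x = create_random_tensors_sketch((4, 64, 64), seeds=[101, 102])
assert x.shape == (2, 4, 64, 64)
```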
@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device)
text_inversion_embeddings = TextInversionEmbeddings()
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.hijack(model)
model_hijack = StableDiffuionModelHijack()
model_hijack.hijack(model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],