fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!)

add option to input initialization text for embeddings
This commit is contained in:
AUTOMATIC 2022-10-02 19:40:51 +03:00
parent 53a3dc601f
commit 88ec0cf557
4 changed files with 13 additions and 14 deletions

View File

@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if embedding is None: if embedding is None:
remade_tokens.append(token) remade_tokens.append(token)
@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum())) used_custom_terms.append((embedding.name, embedding.checksum()))
i += emb_len i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2: if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None: if mult_change is not None:
@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum())) used_custom_terms.append((embedding.name, embedding.checksum()))
i += emb_len i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2: if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}

View File

@ -117,24 +117,21 @@ class EmbeddingDatabase:
possible_matches = self.ids_lookup.get(token, None) possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None: if possible_matches is None:
return None return None, None
for ids, embedding in possible_matches: for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids: if tokens[offset:offset + len(ids)] == ids:
return embedding return embedding, len(ids)
return None return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'):
def create_embedding(name, num_vectors_per_token):
init_text = '*'
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer(ids.to(devices.device)).squeeze(0) embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token): for i in range(num_vectors_per_token):

View File

@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared from modules import sd_hijack, shared
def create_embedding(name, nvpt): def create_embedding(name, initialization_text, nvpt):
filename = ti.create_embedding(name, nvpt) filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

View File

@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
new_embedding_name = gr.Textbox(label="Name") new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row(): with gr.Row():
@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding, fn=modules.textual_inversion.ui.create_embedding,
inputs=[ inputs=[
new_embedding_name, new_embedding_name,
initialization_text,
nvpt, nvpt,
], ],
outputs=[ outputs=[