fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!)
add option to input initialization text for embeddings
parent 53a3dc601f
commit 88ec0cf557
modules/sd_hijack.py

@@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         while i < len(tokens):
             token = tokens[i]
 
-            embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+            embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
             if embedding is None:
                 remade_tokens.append(token)
@@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 remade_tokens += [0] * emb_len
                 multipliers += [weight] * emb_len
                 used_custom_terms.append((embedding.name, embedding.checksum()))
-                i += emb_len
+                i += embedding_length_in_tokens
 
         if len(remade_tokens) > maxlen - 2:
             vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         while i < len(tokens):
             token = tokens[i]
 
-            embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+            embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
             mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
             if mult_change is not None:
@@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 remade_tokens += [0] * emb_len
                 multipliers += [mult] * emb_len
                 used_custom_terms.append((embedding.name, embedding.checksum()))
-                i += emb_len
+                i += embedding_length_in_tokens
 
         if len(remade_tokens) > maxlen - 2:
             vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
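Why the cursor change matters (a reading of the hunks above, hedged because `emb_len` is assigned outside the shown context): `emb_len` appears to count the vectors the embedding expands into, while `embedding_length_in_tokens` counts the prompt tokens its trigger word actually occupies. Advancing `i` by the vector count skips or re-reads prompt tokens whenever the two differ, which is also why the commit message warns about broken seeds. A minimal, self-contained sketch of the corrected loop, with made-up names and data:

```python
# Toy sketch only; not the repo's code. Assume an embedding whose trigger
# phrase tokenizes to 2 prompt tokens but expands to 4 conditioning vectors.

def remake_tokens(tokens, matches):
    """tokens: prompt token ids; matches: {start_index: (num_vectors, length_in_tokens)}."""
    remade = []
    i = 0
    while i < len(tokens):
        match = matches.get(i)
        if match is None:
            remade.append(tokens[i])
            i += 1
        else:
            num_vectors, length_in_tokens = match
            remade += [0] * num_vectors   # placeholders later replaced by the embedding's vectors
            i += length_in_tokens         # fixed: advance by prompt tokens consumed, not by vector count

    return remade

print(remake_tokens([10, 11, 12, 13], {1: (4, 2)}))
# -> [10, 0, 0, 0, 0, 13]; advancing by the vector count (the old `i += emb_len`)
#    would have jumped past token 13 and silently dropped it from the prompt.
```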
modules/textual_inversion/textual_inversion.py

@@ -117,24 +117,21 @@ class EmbeddingDatabase:
         possible_matches = self.ids_lookup.get(token, None)
 
         if possible_matches is None:
-            return None
+            return None, None
 
         for ids, embedding in possible_matches:
             if tokens[offset:offset + len(ids)] == ids:
-                return embedding
+                return embedding, len(ids)
 
-        return None
+        return None, None
 
 
-def create_embedding(name, num_vectors_per_token):
-    init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
 
     ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
-    embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+    embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 
     for i in range(num_vectors_per_token):
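Two related changes in the EmbeddingDatabase hunks above: every return path of `find_embedding_at_position` now yields an `(embedding, length_in_tokens)` pair, so callers must unpack a tuple even on a miss, and `create_embedding` takes the initialization text as a parameter and appears to embed it through the raw token-embedding table (`.token_embedding.wrapped`) rather than the full embeddings module, i.e. without positional embeddings or the hijack wrapper. A hedged sketch of that initialization step using plain transformers/torch calls; the model name and the vector-filling loop (which falls outside the shown context) are assumptions:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Stand-ins for the webui's cond_model wrappers; the model name is an assumption.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
token_embedding = text_model.text_model.embeddings.token_embedding  # plain nn.Embedding


def init_vectors(init_text: str, num_vectors_per_token: int) -> torch.Tensor:
    """Build the starting tensor for a new textual-inversion embedding from init_text."""
    ids = tokenizer(init_text, max_length=num_vectors_per_token,
                    return_tensors="pt", add_special_tokens=False)["input_ids"]
    with torch.no_grad():                                # inference only in this sketch
        embedded = token_embedding(ids).squeeze(0)       # (n_init_tokens, hidden_dim)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]))
    for i in range(num_vectors_per_token):
        # spread the init-text embeddings across the requested vectors (assumed loop body)
        vec[i] = embedded[i * embedded.shape[0] // num_vectors_per_token]
    return vec


vec = init_vectors("*", 3)   # the UI default '*' repeated across 3 vectors
print(vec.shape)             # torch.Size([3, 768]) for this CLIP variant
```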
modules/textual_inversion/ui.py

@@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
 from modules import sd_hijack, shared
 
 
-def create_embedding(name, nvpt):
-    filename = ti.create_embedding(name, nvpt)
+def create_embedding(name, initialization_text, nvpt):
+    filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
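The thin UI wrapper above now takes the initialization text as its second positional argument (matching the order of the Gradio inputs below) and forwards it as the `init_text` keyword. A self-contained sketch of that argument re-ordering, with hypothetical stand-in functions rather than the repo's code:

```python
# Illustrative only: the UI callback receives (name, initialization_text, nvpt)
# positionally, but the lower-level helper expects (name, num_vectors_per_token)
# plus init_text as a keyword, so the wrapper cannot simply splat its arguments through.

def create_embedding_backend(name, num_vectors_per_token, init_text='*'):
    return f"{name}: {num_vectors_per_token} vector(s) initialized from {init_text!r}"

def create_embedding_ui(name, initialization_text, nvpt):
    # same re-ordering as the real wrapper: nvpt moves forward, the text becomes a keyword
    return create_embedding_backend(name, nvpt, init_text=initialization_text)

print(create_embedding_ui("my-style", "a photo of a dog", 4))
# -> my-style: 4 vector(s) initialized from 'a photo of a dog'
```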
modules/ui.py

@@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
                 gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
 
                 new_embedding_name = gr.Textbox(label="Name")
+                initialization_text = gr.Textbox(label="Initialization text", value="*")
                 nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
 
                 with gr.Row():
@@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
             fn=modules.textual_inversion.ui.create_embedding,
             inputs=[
                 new_embedding_name,
+                initialization_text,
                 nvpt,
             ],
             outputs=[
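The two ui.py hunks have to change together: Gradio passes component values to the callback positionally, in the order of the `inputs` list, so the new Textbox must sit between the name and the vector-count slider in both the inputs list and the callback signature. A hedged, standalone sketch of that wiring pattern (the button, labels, and dummy callback are made up, not taken from the diff):

```python
import gradio as gr

def create_embedding(name, initialization_text, nvpt):
    # dummy callback standing in for modules.textual_inversion.ui.create_embedding
    return f"would create '{name}' from '{initialization_text}' with {int(nvpt)} vector(s)"

with gr.Blocks() as demo:
    new_embedding_name = gr.Textbox(label="Name")
    initialization_text = gr.Textbox(label="Initialization text", value="*")
    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
    result = gr.Textbox(label="Result")
    create = gr.Button("Create embedding")

    create.click(
        fn=create_embedding,
        inputs=[new_embedding_name, initialization_text, nvpt],  # order must match the signature
        outputs=[result],
    )

# demo.launch()  # uncomment to try the wiring in isolation
```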