diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fd57e5c5..3fa06242 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) if embedding is None: remade_tokens.append(token) @@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [mult] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c0baaace..0c50161d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -117,24 +117,21 @@ class EmbeddingDatabase: possible_matches = self.ids_lookup.get(token, None) if possible_matches is None: - return None + return None, None for ids, embedding in possible_matches: if tokens[offset:offset + len(ids)] == ids: - return embedding + return embedding, len(ids) - return None + return None, None - -def create_embedding(name, num_vectors_per_token): - init_text = '*' - +def create_embedding(name, num_vectors_per_token, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] - embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) for i in range(num_vectors_per_token): diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index ce3677a9..66c43ffb 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti from modules import sd_hijack, shared -def create_embedding(name, nvpt): - filename = ti.create_embedding(name, nvpt) +def create_embedding(name, initialization_text, nvpt): + filename = ti.create_embedding(name, nvpt, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() diff --git a/modules/ui.py b/modules/ui.py index 3b81a4f7..eca50df0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="

Create a new embedding

") new_embedding_name = gr.Textbox(label="Name") + initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) with gr.Row(): @@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call): fn=modules.textual_inversion.ui.create_embedding, inputs=[ new_embedding_name, + initialization_text, nvpt, ], outputs=[