diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 103ace60..66f40367 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -80,23 +80,8 @@ class EmbeddingDatabase: return embedding def get_expected_shape(self): - expected_shape = -1 # initialize with unknown - idx = torch.tensor(0).to(shared.device) - if expected_shape == -1: - try: # matches sd15 signature - first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx) - expected_shape = first_embedding.shape[0] - except: - pass - if expected_shape == -1: - try: # matches sd20 signature - first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx) - expected_shape = first_embedding.shape[0] - except: - pass - if expected_shape == -1: - print('Could not determine expected embeddings shape from model') - return expected_shape + vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) + return vec.shape[1] def load_textual_inversion_embeddings(self, force_reload = False): mt = os.path.getmtime(self.embeddings_dir) @@ -112,8 +97,6 @@ class EmbeddingDatabase: def process_file(path, filename): name = os.path.splitext(filename)[0] - data = [] - if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: @@ -150,11 +133,10 @@ class EmbeddingDatabase: embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] - if (self.expected_shape == -1) or (self.expected_shape == embedding.shape): + if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) else: self.skipped_embeddings.append(name) - # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape)) for fn in os.listdir(self.embeddings_dir): try: @@ -169,9 +151,9 @@ class EmbeddingDatabase: print(traceback.format_exc(), file=sys.stderr) continue - print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys()))) - if (len(self.skipped_embeddings) > 0): - print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings))) + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") + if len(self.skipped_embeddings) > 0: + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}") def find_embedding_at_position(self, tokens, offset): token = tokens[offset]