From d52a80f7f7da160c73afd067c8f1bf491391f994 Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Thu, 12 Jan 2023 09:22:29 +0100
Subject: [PATCH] Allow creation of zero vectors for TI

---
 modules/textual_inversion/textual_inversion.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b915b091..853246a6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     with devices.autocast():
         cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
 
-    embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
+    #cond_model expects at least some text, so we provide '*' as backup.
+    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 
-    for i in range(num_vectors_per_token):
-        vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+    #Only copy if we provided an init_text, otherwise keep vectors as zeros
+    if init_text:
+        for i in range(num_vectors_per_token):
+            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
 
     # Remove illegal characters from name.
     name = "".join( x for x in name if (x.isalnum() or x in "._- "))
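
For context (not part of the patch): a minimal standalone sketch of the behaviour this change enables. The init_vectors helper and the dummy embedded tensor are illustrative stand-ins for the real code path through cond_model.encode_embedding_init_text; with an empty init_text the embedding stays all zeros, otherwise the encoded text is spread across the vectors exactly as in the patched loop.

    import torch

    def init_vectors(embedded, num_vectors_per_token, init_text):
        # Start from zero vectors; with an empty init_text the patch leaves them untouched.
        vec = torch.zeros((num_vectors_per_token, embedded.shape[1]))
        if init_text:
            # Spread the encoded init_text rows across the requested vectors,
            # mirroring the index arithmetic in the patched loop.
            for i in range(num_vectors_per_token):
                vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
        return vec

    # Dummy stand-in for cond_model.encode_embedding_init_text(init_text or '*', ...)
    embedded = torch.randn(2, 768)
    print(init_vectors(embedded, 4, "").count_nonzero())   # tensor(0): zero vectors
    print(init_vectors(embedded, 4, "*").count_nonzero())  # > 0: copied from embedded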