Allow creation of zero vectors for TI
This commit is contained in:
parent
0b8911d883
commit
d52a80f7f7
|
@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||||
|
|
||||||
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
|
#cond_model expects at least some text, so we provide '*' as backup.
|
||||||
|
embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
|
||||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||||
|
|
||||||
for i in range(num_vectors_per_token):
|
#Only copy if we provided an init_text, otherwise keep vectors as zeros
|
||||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
if init_text:
|
||||||
|
for i in range(num_vectors_per_token):
|
||||||
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
# Remove illegal characters from name.
|
# Remove illegal characters from name.
|
||||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user