Use apply_hypernetwork function
This commit is contained in:
parent
574c8e554a
commit
861db783c7
|
@ -202,16 +202,10 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is not None:
|
||||
k = self.to_k(hypernetwork_layers[0](context)) * self.scale
|
||||
v = self.to_v(hypernetwork_layers[1](context))
|
||||
else:
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k = self.to_k(context_k) * self.scale
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
|
|
Loading…
Reference in New Issue
Block a user