Change sub-quad chunk threshold to use percentage

This commit is contained in:
brkirch 2023-01-06 01:01:51 -05:00
parent b119815333
commit b95a4c0ce5
2 changed files with 10 additions and 10 deletions

View File

@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
return x
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True):
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
_, k_tokens, _ = k.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
if chunk_threshold_bytes is None:
chunk_threshold_bytes = available_vram
elif chunk_threshold_bytes == 0:
if chunk_threshold is None:
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
elif chunk_threshold == 0:
chunk_threshold_bytes = None
else:
chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
if kv_chunk_size_min is None:
if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
elif kv_chunk_size_min == 0:
kv_chunk_size_min = None
@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out

View File

@ -59,7 +59,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")