From 37b64d41ce08c4055b24bc7e4439dcd7bf949374 Mon Sep 17 00:00:00 2001
From: Matthew Smith <usryokousha@gmail.com>
Date: Fri, 31 Mar 2023 11:15:36 +0900
Subject: [PATCH 1/6] incorporated fast attention into attention

---
 torchscale/component/multihead_attention.py | 64 ++++++++++-----------
 1 file changed, 30 insertions(+), 34 deletions(-)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 392d0e9..191b424 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -84,31 +84,26 @@ class MultiheadAttention(nn.Module):
         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
-        q *= self.scaling
 
         q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
-        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
-        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
-        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        q = q.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz, self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz, self.num_heads, src_len, self.head_dim)
 
         if incremental_state is not None:
             if "prev_key" in incremental_state:
                 prev_key = incremental_state["prev_key"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 prev_value = incremental_state["prev_value"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 k = torch.cat([prev_key, k], dim=1)
                 v = torch.cat([prev_value, v], dim=1)
-            incremental_state["prev_key"] = k.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
-            incremental_state["prev_value"] = v.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
+            incremental_state["prev_key"] = k
+            incremental_state["prev_value"] = v
             src_len = k.size(1)
 
         if self.xpos is not None:
@@ -116,42 +111,43 @@ class MultiheadAttention(nn.Module):
                 offset = src_len - 1
             else:
                 offset = 0
+            k, q = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (k, q))
             k = self.xpos(k, offset=0, downscale=True)
             q = self.xpos(q, offset=offset, downscale=False)
-
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
+            k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
 
         if attn_mask is not None:
-            attn_weights = torch.nan_to_num(attn_weights)
             attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
 
         if key_padding_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                float("-inf"),
-            )
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            # Achieve same result with an additive mask
+            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            rel_pos = rel_pos.view(attn_weights.size())
-            attn_weights = attn_weights + rel_pos
+            attn_mask += rel_pos.view(attn_mask.size())
 
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
-            attn_weights
-        )
-        attn_probs = self.dropout_module(attn_weights)
+        if hasattr(F, "scaled_dot_product_attention"):
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask, self.dropout_module.p
+            )
+            attn_weights = None
+        else:
+            q *= self.scaling
+            q, k, v = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (q, k, v))
+            attn_weights = torch.bmm(q, k.transpose(1, 2))
 
-        attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
+                attn_weights
+            )
+            attn_weights = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            attn_probs = self.dropout_module(attn_weights)
+
+            attn = torch.bmm(attn_probs, v)
+            attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
 
         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)
 
         attn = self.out_proj(attn)
-        attn_weights = attn_weights.view(
-            bsz, self.num_heads, tgt_len, src_len
-        ).transpose(1, 0)
-
         return attn, attn_weights

From a5a94191a15e0a5dbab577ada8c1bccd410ac44f Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:08:47 -0700
Subject: [PATCH 2/6] Update multihead_attention.py

---
 torchscale/component/multihead_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 191b424..7895f8a 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            attn_mask = attn_make + rel_pos.view(attn_mask.size())
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(

From 412a1a3878567e608d544ccc6c0c0a7dce128e17 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:17:41 -0700
Subject: [PATCH 3/6] Update multihead_attention.py

---
 torchscale/component/multihead_attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 7895f8a..1d736bf 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -121,7 +121,8 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
 
         if rel_pos is not None:
             attn_mask = attn_make + rel_pos.view(attn_mask.size())

From d4a62ccfb512bad8b9ede1eb405e3e6d75934370 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:28:08 -0700
Subject: [PATCH 4/6] Update multihead_attention.py

---
 torchscale/component/multihead_attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 1d736bf..9782024 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -118,6 +118,8 @@ class MultiheadAttention(nn.Module):
 
         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
+        else:
+            attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask

From 62cedb9c8fabcf0077d5dbbfff6a44b2af05a321 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:45:48 -0700
Subject: [PATCH 5/6] Update multihead_attention.py

---
 torchscale/component/multihead_attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 9782024..4d76384 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -133,6 +133,9 @@ class MultiheadAttention(nn.Module):
             attn = F.scaled_dot_product_attention(
                 q, k, v, attn_mask, self.dropout_module.p
             )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
+            # Permute to B,T,H,E, and then flatten to B,T,D
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
             attn_weights = None
         else:
             q *= self.scaling

From 29c6eadb8314cb1dd86f20c81f7614870bfae3e9 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mranzinger@nvidia.com>
Date: Tue, 9 May 2023 19:21:25 +0000
Subject: [PATCH 6/6] Masks are now optional, and not created. Fixes required
 to support FlashAttention (e.g. no mask, fp/bf16)

---
 torchscale/architecture/encoder.py          | 14 ++--------
 torchscale/component/multihead_attention.py | 32 ++++++++++++++-------
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index 62ab174..103df01 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -339,23 +339,13 @@ class Encoder(nn.Module):
     ):
         assert src_tokens is not None or token_embeddings is not None
 
-        if encoder_padding_mask is None:
-            if src_tokens is not None:
-                encoder_padding_mask = torch.zeros_like(
-                    src_tokens, device=src_tokens.device
-                ).bool()
-            else:
-                encoder_padding_mask = torch.zeros(
-                    [token_embeddings.size(0), token_embeddings.size(1)],
-                    device=token_embeddings.device,
-                ).bool()
-
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
 
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        if encoder_padding_mask is not None:
+            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
         encoder_states = []
 
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 4d76384..908f232 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -2,6 +2,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module):
 
     def forward(
         self,
-        query,
-        key,
-        value,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         incremental_state=None,
-        key_padding_mask=None,
-        attn_mask=None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         rel_pos=None,
     ):
         bsz, tgt_len, embed_dim = query.size()
@@ -116,18 +117,27 @@ class MultiheadAttention(nn.Module):
             q = self.xpos(q, offset=offset, downscale=False)
             k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
 
-        if attn_mask is not None:
-            attn_mask = attn_mask.unsqueeze(0)
-        else:
-            attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
+        if attn_mask is not None and attn_mask.ndim != 4:
+            # Add batch and heads
+            attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
+        # else:
+        #     attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
             key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
-            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
+            # Add heads and dst_len
+            key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
+            if attn_mask is not None:
+                attn_mask = attn_mask + key_padding_mask
+            else:
+                attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)
 
         if rel_pos is not None:
-            attn_mask = attn_make + rel_pos.view(attn_mask.size())
+            if attn_mask is not None:
+                attn_mask = attn_mask + rel_pos.view(attn_mask.size())
+            else:
+                attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
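
The mask handling this series converges on can be exercised in isolation. The snippet below is an illustrative sketch only, not the torchscale code: the function name sdpa_with_padding_mask and the tensor shapes are invented for the example. It shows a boolean key-padding mask converted into the additive float mask that F.scaled_dot_product_attention accepts (as in PATCH 3/6 and 6/6), with an explicit softmax fallback in the spirit of the else branch from PATCH 1/6.

# Sketch only -- not the torchscale MultiheadAttention implementation.
# Shapes: q, k, v are (bsz, num_heads, seq_len, head_dim); key_padding_mask is
# (bsz, src_len) with True marking padded source positions.
import torch
import torch.nn.functional as F


def sdpa_with_padding_mask(q, k, v, key_padding_mask=None, dropout_p=0.0):
    bsz, num_heads, tgt_len, head_dim = q.shape
    src_len = k.shape[2]

    attn_mask = None
    if key_padding_mask is not None:
        # Boolean mask -> additive float mask: -inf on padded keys, 0 elsewhere.
        additive = torch.where(key_padding_mask, float("-inf"), 0.0).to(q.dtype)
        # Broadcast over heads and query positions: (bsz, 1, 1, src_len).
        attn_mask = additive.view(bsz, 1, 1, src_len).expand(
            bsz, num_heads, tgt_len, src_len
        )

    if hasattr(F, "scaled_dot_product_attention"):
        # Fused path; q is scaled by 1/sqrt(head_dim) internally.
        return F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)

    # Fallback: explicit softmax(q @ k^T / sqrt(d) + mask) @ v.
    scores = torch.matmul(q, k.transpose(-2, -1)) / head_dim ** 0.5
    if attn_mask is not None:
        scores = scores + attn_mask
    probs = F.dropout(scores.softmax(dim=-1), p=dropout_p)
    return torch.matmul(probs, v)


# Example: batch of 2, 4 heads, last two source positions padded.
q = torch.randn(2, 4, 5, 16)
k = torch.randn(2, 4, 7, 16)
v = torch.randn(2, 4, 7, 16)
pad = torch.zeros(2, 7, dtype=torch.bool)
pad[:, -2:] = True
out = sdpa_with_padding_mask(q, k, v, pad)   # (2, 4, 5, 16) == B,H,T,E
out = out.permute(0, 2, 1, 3).flatten(2)     # (2, 5, 64)    == B,T,D, as in PATCH 5/6

When no mask is supplied at all, attn_mask stays None; that is the case PATCH 6/6 targets, since scaled_dot_product_attention can then dispatch to its fused FlashAttention-style kernels.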