From 37b64d41ce08c4055b24bc7e4439dcd7bf949374 Mon Sep 17 00:00:00 2001
From: Matthew Smith <usryokousha@gmail.com>
Date: Fri, 31 Mar 2023 11:15:36 +0900
Subject: [PATCH 1/6] Incorporate F.scaled_dot_product_attention into MultiheadAttention

---
 torchscale/component/multihead_attention.py | 64 ++++++++++-----------
 1 file changed, 30 insertions(+), 34 deletions(-)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 392d0e9..191b424 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -84,31 +84,26 @@ class MultiheadAttention(nn.Module):
         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
-        q *= self.scaling
 
         q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
-        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
-        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
-        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        q = q.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz, self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz, self.num_heads, src_len, self.head_dim)
 
         if incremental_state is not None:
             if "prev_key" in incremental_state:
                 prev_key = incremental_state["prev_key"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 prev_value = incremental_state["prev_value"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 k = torch.cat([prev_key, k], dim=1)
                 v = torch.cat([prev_value, v], dim=1)
-            incremental_state["prev_key"] = k.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
-            incremental_state["prev_value"] = v.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
+            incremental_state["prev_key"] = k
+            incremental_state["prev_value"] = v
             src_len = k.size(1)
 
         if self.xpos is not None:
@@ -116,42 +111,43 @@ class MultiheadAttention(nn.Module):
                 offset = src_len - 1
             else:
                 offset = 0
+            k, q = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (k, q))
             k = self.xpos(k, offset=0, downscale=True)
             q = self.xpos(q, offset=offset, downscale=False)
-
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
+            k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
 
         if attn_mask is not None:
-            attn_weights = torch.nan_to_num(attn_weights)
             attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
 
         if key_padding_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                float("-inf"),
-            )
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            # Achieve same result with an additive mask
+            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            rel_pos = rel_pos.view(attn_weights.size())
-            attn_weights = attn_weights + rel_pos
+            attn_mask += rel_pos.view(attn_mask.size())
 
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
-            attn_weights
-        )
-        attn_probs = self.dropout_module(attn_weights)
+        if hasattr(F, "scaled_dot_product_attention"):
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask, self.dropout_module.p
+            )
+            attn_weights = None
+        else:
+            q *= self.scaling
+            q, k, v = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (q, k, v))
+            attn_weights = torch.bmm(q, k.transpose(1, 2))
 
-        attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
+                attn_weights
+            )
+            attn_probs = self.dropout_module(attn_weights)
+
+            attn = torch.bmm(attn_probs, v)
+            attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
+            attn_weights = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
 
         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)
 
         attn = self.out_proj(attn)
-        attn_weights = attn_weights.view(
-            bsz, self.num_heads, tgt_len, src_len
-        ).transpose(1, 0)
-
         return attn, attn_weights
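
Note on this patch: the attention kernel is now chosen at runtime. When
torch.nn.functional.scaled_dot_product_attention exists (PyTorch >= 2.0) the fused
op handles scaling, masking, softmax and dropout; otherwise the original
bmm/softmax path runs. A minimal, self-contained sketch of that dispatch pattern
follows; the function name and tensor sizes are illustrative, and the fallback
here uses matmul on 4-D tensors rather than the module's bmm over flattened heads.

    import math

    import torch
    import torch.nn.functional as F

    def fused_or_naive_attention(q, k, v, attn_mask=None, dropout_p=0.0):
        # q, k, v: (batch, heads, seq_len, head_dim); attn_mask: additive, broadcastable.
        if hasattr(F, "scaled_dot_product_attention"):
            # Fused path: scaling, masking, softmax and dropout all happen inside the op.
            return F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=dropout_p
            )
        # Fallback: explicit scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        if attn_mask is not None:
            scores = scores + attn_mask
        probs = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p)
        return torch.matmul(probs, v)

    # e.g. out = fused_or_naive_attention(*(torch.randn(2, 8, 16, 64) for _ in range(3)))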

From a5a94191a15e0a5dbab577ada8c1bccd410ac44f Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:08:47 -0700
Subject: [PATCH 2/6] Build the additive attention mask without in-place addition

---
 torchscale/component/multihead_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 191b424..7895f8a 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            attn_mask = attn_mask + rel_pos.view(attn_mask.size())
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
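
Note on this patch: switching from "attn_mask += ..." to "attn_mask = attn_mask + ..."
matters because an in-place add cannot broadcast its destination to a larger shape:
attn_mask has a singleton batch dimension here, while the key-padding term carries the
real batch size. A small sketch of the failure mode (shapes are illustrative); the NaNs
produced by multiplying the mask by -inf are addressed in the next patch.

    import torch

    attn_mask = torch.zeros(1, 4, 4)   # (1, tgt_len, src_len)
    pad_term = torch.zeros(2, 1, 4)    # (bsz, 1, src_len)

    ok = attn_mask + pad_term          # out-of-place: broadcasts to (2, 4, 4)

    try:
        attn_mask += pad_term          # in-place: result must keep shape (1, 4, 4)
    except RuntimeError as err:
        print(err)                     # "... doesn't match the broadcast shape ..."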

From 412a1a3878567e608d544ccc6c0c0a7dce128e17 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:17:41 -0700
Subject: [PATCH 3/6] Convert the key padding mask to an additive mask with torch.where

---
 torchscale/component/multihead_attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 7895f8a..1d736bf 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -121,7 +121,8 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
 
         if rel_pos is not None:
             attn_mask = attn_mask + rel_pos.view(attn_mask.size())
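
Note on this patch: multiplying a 0/1 padding mask by float("-inf") leaves NaN at every
unmasked position (0 * -inf is NaN), and those NaNs survive the softmax. torch.where
writes -inf only at the padded positions and an exact 0.0 elsewhere. A small sketch of
the difference, mirroring the call used in the patch (shapes are illustrative):

    import torch

    key_padding_mask = torch.tensor([[False, False, True]])    # (bsz, src_len), True = pad

    bad = key_padding_mask.to(torch.float32) * float("-inf")   # tensor([[nan, nan, -inf]])
    good = torch.where(key_padding_mask, float("-inf"), 0.0)   # tensor([[0., 0., -inf]])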

From d4a62ccfb512bad8b9ede1eb405e3e6d75934370 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:28:08 -0700
Subject: [PATCH 4/6] Create a zero attention mask when none is provided

---
 torchscale/component/multihead_attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 1d736bf..9782024 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -118,6 +118,8 @@ class MultiheadAttention(nn.Module):
 
         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
+        else:
+            attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask

From 62cedb9c8fabcf0077d5dbbfff6a44b2af05a321 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mikeranzinger@gmail.com>
Date: Sun, 23 Apr 2023 18:45:48 -0700
Subject: [PATCH 5/6] Reshape the scaled_dot_product_attention output to (batch, tgt_len, embed_dim)

---
 torchscale/component/multihead_attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 9782024..4d76384 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -133,6 +133,9 @@ class MultiheadAttention(nn.Module):
             attn = F.scaled_dot_product_attention(
                 q, k, v, attn_mask, self.dropout_module.p
             )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
+            # Permute to B,T,H,E, and then flatten to B,T,D
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
             attn_weights = None
         else:
             q *= self.scaling
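
Note on this patch: F.scaled_dot_product_attention returns its output with the heads
still separate, as (batch, heads, tgt_len, head_dim), while inner_attn_ln and out_proj
expect (batch, tgt_len, embed_dim) with the heads folded back into the channel
dimension. A shape-only sketch of the permute + flatten added above (sizes are
illustrative):

    import torch

    bsz, num_heads, tgt_len, head_dim = 2, 8, 16, 64

    attn = torch.randn(bsz, num_heads, tgt_len, head_dim)  # B, H, T, E
    attn = attn.permute(0, 2, 1, 3)                        # B, T, H, E
    attn = attn.flatten(2)                                 # B, T, H*E == embed_dim

    assert attn.shape == (bsz, tgt_len, num_heads * head_dim)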

From 29c6eadb8314cb1dd86f20c81f7614870bfae3e9 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger <mranzinger@nvidia.com>
Date: Tue, 9 May 2023 19:21:25 +0000
Subject: [PATCH 6/6] Masks are now optional and are no longer created. Fixes
 required to support FlashAttention (e.g. no attn_mask, fp16/bf16 inputs)

---
 torchscale/architecture/encoder.py          | 14 ++-------
 torchscale/component/multihead_attention.py | 32 ++++++++++++++-------
 2 files changed, 23 insertions(+), 23 deletions(-)
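
Note on this patch: the FlashAttention kernel behind F.scaled_dot_product_attention is
only eligible (as of PyTorch 2.0) when no explicit attn_mask is passed and the inputs
are fp16/bf16, so the zero-filled default mask from the earlier patch is dropped,
attn_mask stays None whenever no masking is needed, and the encoder no longer fabricates
an all-zero padding mask. A hedged sketch of exercising the flash backend directly; it
assumes a CUDA build of PyTorch 2.0 and a GPU supported by FlashAttention, the sizes are
illustrative, and torch.backends.cuda.sdp_kernel is not part of this patch.

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

    # Restrict SDPA to the FlashAttention backend; this only succeeds because the
    # inputs are half precision and no explicit attn_mask is supplied.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)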

diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index 62ab174..103df01 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -339,23 +339,13 @@ class Encoder(nn.Module):
     ):
         assert src_tokens is not None or token_embeddings is not None
 
-        if encoder_padding_mask is None:
-            if src_tokens is not None:
-                encoder_padding_mask = torch.zeros_like(
-                    src_tokens, device=src_tokens.device
-                ).bool()
-            else:
-                encoder_padding_mask = torch.zeros(
-                    [token_embeddings.size(0), token_embeddings.size(1)],
-                    device=token_embeddings.device,
-                ).bool()
-
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
 
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        if encoder_padding_mask is not None:
+            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
         encoder_states = []
 
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 4d76384..908f232 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -2,6 +2,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module):
 
     def forward(
         self,
-        query,
-        key,
-        value,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         incremental_state=None,
-        key_padding_mask=None,
-        attn_mask=None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         rel_pos=None,
     ):
         bsz, tgt_len, embed_dim = query.size()
@@ -116,18 +117,27 @@ class MultiheadAttention(nn.Module):
             q = self.xpos(q, offset=offset, downscale=False)
             k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
 
-        if attn_mask is not None:
-            attn_mask = attn_mask.unsqueeze(0)
-        else:
-            attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
+        if attn_mask is not None and attn_mask.ndim != 4:
+            # Add batch and heads
+            attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
+        # else:
+        #     attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
             key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
-            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
+            # Add heads and tgt_len dims
+            key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
+            if attn_mask is not None:
+                attn_mask = attn_mask + key_padding_mask
+            else:
+                attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)
 
         if rel_pos is not None:
-            attn_mask = attn_mask + rel_pos.view(attn_mask.size())
+            if attn_mask is not None:
+                attn_mask = attn_mask + rel_pos.view(attn_mask.size())
+            else:
+                attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(