From a5a94191a15e0a5dbab577ada8c1bccd410ac44f Mon Sep 17 00:00:00 2001 From: Mike Ranzinger Date: Sun, 23 Apr 2023 18:08:47 -0700 Subject: [PATCH 1/5] Update multihead_attention.py --- torchscale/component/multihead_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 191b424..7895f8a 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module): if key_padding_mask is not None: # Achieve same result with an additive mask - attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf") + attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf") if rel_pos is not None: - attn_mask += rel_pos.view(attn_mask.size()) + attn_mask = attn_make + rel_pos.view(attn_mask.size()) if hasattr(F, "scaled_dot_product_attention"): attn = F.scaled_dot_product_attention( From 412a1a3878567e608d544ccc6c0c0a7dce128e17 Mon Sep 17 00:00:00 2001 From: Mike Ranzinger Date: Sun, 23 Apr 2023 18:17:41 -0700 Subject: [PATCH 2/5] Update multihead_attention.py --- torchscale/component/multihead_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 7895f8a..1d736bf 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -121,7 +121,8 @@ class MultiheadAttention(nn.Module): if key_padding_mask is not None: # Achieve same result with an additive mask - attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf") + key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0) + attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) if rel_pos is not None: attn_mask = attn_make + rel_pos.view(attn_mask.size()) From d4a62ccfb512bad8b9ede1eb405e3e6d75934370 Mon Sep 17 00:00:00 2001 From: Mike Ranzinger Date: Sun, 23 Apr 2023 18:28:08 -0700 Subject: [PATCH 3/5] Update multihead_attention.py --- torchscale/component/multihead_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 1d736bf..9782024 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -118,6 +118,8 @@ class MultiheadAttention(nn.Module): if attn_mask is not None: attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device) if key_padding_mask is not None: # Achieve same result with an additive mask From 62cedb9c8fabcf0077d5dbbfff6a44b2af05a321 Mon Sep 17 00:00:00 2001 From: Mike Ranzinger Date: Sun, 23 Apr 2023 18:45:48 -0700 Subject: [PATCH 4/5] Update multihead_attention.py --- torchscale/component/multihead_attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 9782024..4d76384 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -133,6 +133,9 @@ class MultiheadAttention(nn.Module): attn = F.scaled_dot_product_attention( q, k, v, attn_mask, self.dropout_module.p ) + # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim) + # Permute to B,T,H,E, and then flatten to B,T,D + attn = attn.permute(0, 2, 1, 3).flatten(2) attn_weights = None else: q *= self.scaling From 29c6eadb8314cb1dd86f20c81f7614870bfae3e9 Mon Sep 17 00:00:00 2001 From: Mike Ranzinger Date: Tue, 9 May 2023 19:21:25 +0000 Subject: [PATCH 5/5] Masks are now optional, and not created. Fixes required to support FlashAttention (e.g. no mask, fp/bf16) --- torchscale/architecture/encoder.py | 14 ++------- torchscale/component/multihead_attention.py | 32 ++++++++++++++------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py index 62ab174..103df01 100644 --- a/torchscale/architecture/encoder.py +++ b/torchscale/architecture/encoder.py @@ -339,23 +339,13 @@ class Encoder(nn.Module): ): assert src_tokens is not None or token_embeddings is not None - if encoder_padding_mask is None: - if src_tokens is not None: - encoder_padding_mask = torch.zeros_like( - src_tokens, device=src_tokens.device - ).bool() - else: - encoder_padding_mask = torch.zeros( - [token_embeddings.size(0), token_embeddings.size(1)], - device=token_embeddings.device, - ).bool() - if multiway_split_position is not None: assert self.args.multiway self.apply(set_split_position(multiway_split_position)) x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions) - x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + if encoder_padding_mask is not None: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) encoder_states = [] diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 4d76384..908f232 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -2,6 +2,7 @@ # Licensed under The MIT License [see LICENSE for details] import math +from typing import Optional import torch import torch.nn.functional as F @@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module): def forward( self, - query, - key, - value, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, incremental_state=None, - key_padding_mask=None, - attn_mask=None, + key_padding_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, rel_pos=None, ): bsz, tgt_len, embed_dim = query.size() @@ -116,18 +117,27 @@ class MultiheadAttention(nn.Module): q = self.xpos(q, offset=offset, downscale=False) k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q)) - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(0) - else: - attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device) + if attn_mask is not None and attn_mask.ndim != 4: + # Add batch and heads + attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1) + # else: + # attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device) if key_padding_mask is not None: # Achieve same result with an additive mask key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0) - attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) + # Add heads and dst_len + key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1) + if attn_mask is not None: + attn_mask = attn_mask + key_padding_mask + else: + attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1) if rel_pos is not None: - attn_mask = attn_make + rel_pos.view(attn_mask.size()) + if attn_mask is not None: + attn_mask = attn_mask + rel_pos.view(attn_mask.size()) + else: + attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len) if hasattr(F, "scaled_dot_product_attention"): attn = F.scaled_dot_product_attention(