diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index 62ab174..103df01 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -339,23 +339,13 @@ class Encoder(nn.Module):
     ):
         assert src_tokens is not None or token_embeddings is not None
 
-        if encoder_padding_mask is None:
-            if src_tokens is not None:
-                encoder_padding_mask = torch.zeros_like(
-                    src_tokens, device=src_tokens.device
-                ).bool()
-            else:
-                encoder_padding_mask = torch.zeros(
-                    [token_embeddings.size(0), token_embeddings.size(1)],
-                    device=token_embeddings.device,
-                ).bool()
-
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
 
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        if encoder_padding_mask is not None:
+            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
         encoder_states = []
 
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 33e917e..3992766 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -2,6 +2,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module):
 
     def forward(
        self,
-        query,
-        key,
-        value,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         incremental_state=None,
-        key_padding_mask=None,
-        attn_mask=None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         rel_pos=None,
         is_first_step=False,
     ):
@@ -85,31 +86,26 @@ class MultiheadAttention(nn.Module):
         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
-        q *= self.scaling
 
         q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
-        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
-        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
-        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        q = q.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz, self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz, self.num_heads, src_len, self.head_dim)
 
         if incremental_state is not None:
             if "prev_key" in incremental_state:
                 prev_key = incremental_state["prev_key"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 prev_value = incremental_state["prev_value"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 k = torch.cat([prev_key, k], dim=1)
                 v = torch.cat([prev_value, v], dim=1)
-            incremental_state["prev_key"] = k.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
-            incremental_state["prev_value"] = v.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
+            incremental_state["prev_key"] = k
+            incremental_state["prev_value"] = v
             src_len = k.size(1)
 
         if self.xpos is not None:
@@ -117,42 +113,58 @@ class MultiheadAttention(nn.Module):
                 offset = src_len - 1
             else:
                 offset = 0
+            k, q = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (k, q))
             k = self.xpos(k, offset=0, downscale=True)
             q = self.xpos(q, offset=offset, downscale=False)
+            k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
 
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
-
-        if attn_mask is not None:
-            attn_weights = torch.nan_to_num(attn_weights)
-            attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
+        if attn_mask is not None and attn_mask.ndim != 4:
+            # Add batch and heads
+            attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
+        # else:
+        #     attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                float("-inf"),
-            )
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            # Achieve same result with an additive mask
+            key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
+            # Add heads and dst_len
+            key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
+            if attn_mask is not None:
+                attn_mask = attn_mask + key_padding_mask
+            else:
+                attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)
 
         if rel_pos is not None:
-            rel_pos = rel_pos.view(attn_weights.size())
-            attn_weights = attn_weights + rel_pos
+            if attn_mask is not None:
+                attn_mask = attn_mask + rel_pos.view(attn_mask.size())
+            else:
+                attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)
 
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
-            attn_weights
-        )
-        attn_probs = self.dropout_module(attn_weights)
+        if hasattr(F, "scaled_dot_product_attention"):
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask, self.dropout_module.p
+            )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
+            # Permute to B,T,H,E, and then flatten to B,T,D
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
+            attn_weights = None
+        else:
+            q *= self.scaling
+            q, k, v = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (q, k, v))
+            attn_weights = torch.bmm(q, k.transpose(1, 2))
 
-        attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
+                attn_weights
+            )
+            attn_weights = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            attn_probs = self.dropout_module(attn_weights)
+
+            attn = torch.bmm(attn_probs, v)
+            attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
 
         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)
 
         attn = self.out_proj(attn)
-        attn_weights = attn_weights.view(
-            bsz, self.num_heads, tgt_len, src_len
-        ).transpose(1, 0)
-
         return attn, attn_weights
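Reviewer note (not part of the patch): the key behavioral claim is that converting the boolean `key_padding_mask` into an additive `-inf` mask and passing it to `F.scaled_dot_product_attention` reproduces the old scale/`bmm`/`masked_fill`/softmax path. Below is a minimal, standalone sanity-check sketch of that equivalence, assuming torch>=2.0 (so `F.scaled_dot_product_attention` exists) and dropout disabled so both paths are deterministic; the tensor names mirror the diff, but the script itself is illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

# Boolean key padding mask: True marks padded source positions.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[0, -2:] = True

# Old path: scale q, compute scores, masked_fill with -inf, softmax, weighted sum.
scaling = head_dim ** -0.5
scores = torch.einsum("bhtd,bhsd->bhts", q * scaling, k)
scores = scores.masked_fill(key_padding_mask.view(bsz, 1, 1, src_len), float("-inf"))
ref = torch.einsum("bhts,bhsd->bhtd", scores.softmax(dim=-1), v)

# New path: additive float mask broadcast to (bsz, num_heads, tgt_len, src_len);
# SDPA applies the 1/sqrt(head_dim) scaling internally.
additive = torch.where(key_padding_mask, float("-inf"), 0.0)
additive = additive.view(bsz, 1, 1, src_len).expand(-1, num_heads, tgt_len, -1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive, dropout_p=0.0)

print(torch.allclose(ref, out, atol=1e-6))  # expected: True
```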