diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index 62ab174..103df01 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -339,23 +339,13 @@ class Encoder(nn.Module):
     ):
         assert src_tokens is not None or token_embeddings is not None
 
-        if encoder_padding_mask is None:
-            if src_tokens is not None:
-                encoder_padding_mask = torch.zeros_like(
-                    src_tokens, device=src_tokens.device
-                ).bool()
-            else:
-                encoder_padding_mask = torch.zeros(
-                    [token_embeddings.size(0), token_embeddings.size(1)],
-                    device=token_embeddings.device,
-                ).bool()
-
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
 
         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        if encoder_padding_mask is not None:
+            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
         encoder_states = []
 
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 191b424..908f232 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -2,6 +2,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module):
 
     def forward(
         self,
-        query,
-        key,
-        value,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         incremental_state=None,
-        key_padding_mask=None,
-        attn_mask=None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         rel_pos=None,
     ):
         bsz, tgt_len, embed_dim = query.size()
@@ -116,20 +117,35 @@ class MultiheadAttention(nn.Module):
             q = self.xpos(q, offset=offset, downscale=False)
 
         k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
-        if attn_mask is not None:
-            attn_mask = attn_mask.unsqueeze(0)
+        if attn_mask is not None and attn_mask.ndim != 4:
+            # Add batch and heads
+            attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
+        # else:
+        #     attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
+            # Add heads and dst_len
+            key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
+            if attn_mask is not None:
+                attn_mask = attn_mask + key_padding_mask
+            else:
+                attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)
 
         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            if attn_mask is not None:
+                attn_mask = attn_mask + rel_pos.view(attn_mask.size())
+            else:
+                attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
                 q, k, v, attn_mask, self.dropout_module.p
             )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
+            # Permute to B,T,H,E, and then flatten to B,T,D
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
             attn_weights = None
         else:
             q *= self.scaling
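Not part of the patch: a minimal standalone sketch of the mask convention the rewritten forward() relies on, i.e. turning a boolean key_padding_mask into an additive float mask broadcast to (bsz, num_heads, tgt_len, src_len) for F.scaled_dot_product_attention, then flattening the (B, H, T, E) output back to (B, T, D). The tensor names mirror the diff; the toy sizes and random inputs are made up for illustration.

# Standalone illustration only -- not part of the patch; toy sizes are arbitrary.
import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8

q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

# Boolean padding mask: True marks padded source positions to be ignored.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # pretend the last source token is padding

# Convert to an additive mask (0 = keep, -inf = masked) and broadcast to
# (bsz, num_heads, tgt_len, src_len), as the patched forward() does before SDPA.
additive_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
additive_mask = additive_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, num_heads, tgt_len, -1)

attn = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask, dropout_p=0.0)

# SDPA returns (bsz, num_heads, tgt_len, head_dim); merge heads back into the model dim.
out = attn.permute(0, 2, 1, 3).flatten(2)
print(out.shape)  # torch.Size([2, 5, 32])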