diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 392d0e9..191b424 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -84,31 +84,26 @@ class MultiheadAttention(nn.Module):
         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
-        q *= self.scaling
 
         q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
-        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
-        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
-        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        q = q.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz, self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz, self.num_heads, src_len, self.head_dim)
 
         if incremental_state is not None:
             if "prev_key" in incremental_state:
                 prev_key = incremental_state["prev_key"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 prev_value = incremental_state["prev_value"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 k = torch.cat([prev_key, k], dim=1)
                 v = torch.cat([prev_value, v], dim=1)
-            incremental_state["prev_key"] = k.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
-            incremental_state["prev_value"] = v.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
+            incremental_state["prev_key"] = k
+            incremental_state["prev_value"] = v
             src_len = k.size(1)
 
         if self.xpos is not None:
@@ -116,42 +111,43 @@ class MultiheadAttention(nn.Module):
                 offset = src_len - 1
             else:
                 offset = 0
+            k, q = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (k, q))
             k = self.xpos(k, offset=0, downscale=True)
             q = self.xpos(q, offset=offset, downscale=False)
-
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
+            k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
 
         if attn_mask is not None:
-            attn_weights = torch.nan_to_num(attn_weights)
             attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
 
         if key_padding_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                float("-inf"),
-            )
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            # Achieve same result with an additive mask
+            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            rel_pos = rel_pos.view(attn_weights.size())
-            attn_weights = attn_weights + rel_pos
+            attn_mask += rel_pos.view(attn_mask.size())
 
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
-            attn_weights
-        )
-        attn_probs = self.dropout_module(attn_weights)
+        if hasattr(F, "scaled_dot_product_attention"):
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask, self.dropout_module.p
+            )
+            attn_weights = None
+        else:
+            q *= self.scaling
+            q, k, v = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (q, k, v))
+            attn_weights = torch.bmm(q, k.transpose(1, 2))
 
-        attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
+                attn_weights
+            )
+            attn_weights = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            attn_probs = self.dropout_module(attn_weights)
+
+            attn = torch.bmm(attn_probs, v)
+            attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
 
         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)
 
         attn = self.out_proj(attn)
-        attn_weights = attn_weights.view(
-            bsz, self.num_heads, tgt_len, src_len
-        ).transpose(1, 0)
-
         return attn, attn_weights
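For reference, here is a minimal standalone sketch (not part of the patch) of how the new fast path and the retained fallback are expected to agree: it builds an additive key-padding mask on 4D `(bsz, num_heads, seq, head_dim)` tensors, runs `F.scaled_dot_product_attention`, and compares the result against an explicit scale/matmul/softmax computation. The shapes, the `masked_fill`-based mask construction, and the tolerance are illustrative assumptions, not code from this repository.

```python
# Sketch only: checks that the SDPA fast path and an explicit attention
# computation (mirroring the fallback branch above) produce the same output.
# Assumes PyTorch >= 2.0 so that F.scaled_dot_product_attention exists.
import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

# Additive mask: 0.0 where attention is allowed, -inf where the key is padding.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # treat the last key position as padding
attn_mask = torch.zeros(bsz, 1, tgt_len, src_len)
attn_mask = attn_mask.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))

# Fast path: scaling, masking, and softmax happen inside the fused kernel
# (dropout_p=0.0 so the comparison is deterministic).
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=0.0)

# Explicit path: scale q, matmul with k, add the mask, softmax, matmul with v.
scaling = head_dim ** -0.5
attn_weights = (q * scaling) @ k.transpose(-2, -1) + attn_mask
attn_weights = F.softmax(attn_weights, dim=-1)
out_manual = attn_weights @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))
```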