diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 9782024..4d76384 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -133,6 +133,9 @@ class MultiheadAttention(nn.Module):
             attn = F.scaled_dot_product_attention(
                 q, k, v, attn_mask, self.dropout_module.p
             )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
+            # Permute to B,T,H,E, and then flatten to B,T,D
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
             attn_weights = None
         else:
             q *= self.scaling
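
For context, here is a minimal standalone sketch (not part of the patch) of the shape transformation the added lines perform, assuming `q`, `k`, `v` are laid out as (B, H, T, E); the sizes below are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
B, H, T, E = 2, 8, 16, 64  # batch, heads, target length, head dim

q = torch.randn(B, H, T, E)
k = torch.randn(B, H, T, E)
v = torch.randn(B, H, T, E)

# SDPA returns (B, H, T, E) for (B, H, T, E) inputs.
attn = F.scaled_dot_product_attention(q, k, v)

# Permute to (B, T, H, E), then flatten the last two dims into D = H * E.
attn = attn.permute(0, 2, 1, 3).flatten(2)
assert attn.shape == (B, T, H * E)
```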