Update multihead_attention.py

Mike Ranzinger 2023-04-23 18:45:48 -07:00 committed by GitHub
parent d4a62ccfb5
commit 62cedb9c8f


@@ -133,6 +133,9 @@ class MultiheadAttention(nn.Module):
             attn = F.scaled_dot_product_attention(
                 q, k, v, attn_mask, self.dropout_module.p
             )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Head_Dim)
+            # Permute to B,T,H,E, then flatten to B,T,D (D = H*E)
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
             attn_weights = None
         else:
             q *= self.scaling
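
For context, here is a minimal standalone sketch (not part of the commit) of why the added permute-and-flatten step is needed: `F.scaled_dot_product_attention` returns its output with the heads dimension ahead of the sequence dimension, so the result must be rearranged before it can be fed to the usual output projection. The tensor sizes below are illustrative, and the call omits the mask and dropout arguments used in the actual diff.

```python
import torch
import torch.nn.functional as F

B, H, T, E = 2, 8, 16, 64  # batch, heads, target length, head dim (illustrative)
D = H * E                  # model (embedding) dim

q = torch.randn(B, H, T, E)
k = torch.randn(B, H, T, E)
v = torch.randn(B, H, T, E)

# SDPA output is (B, H, T, E): heads come before the sequence dimension.
attn = F.scaled_dot_product_attention(q, k, v)

# Move heads next to the head dim, then merge them: (B, T, H, E) -> (B, T, D).
attn = attn.permute(0, 2, 1, 3).flatten(2)
assert attn.shape == (B, T, D)
```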