Update multihead_attention.py

This commit is contained in:
Mike Ranzinger 2023-04-23 18:17:41 -07:00 committed by GitHub
parent a5a94191a1
commit 412a1a3878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -121,7 +121,8 @@ class MultiheadAttention(nn.Module):
if key_padding_mask is not None:
# Achieve same result with an additive mask
attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
if rel_pos is not None:
attn_mask = attn_mask + rel_pos.view(attn_mask.size())