Update multihead_attention.py

Author: Mike Ranzinger, 2023-04-23 18:28:08 -07:00 (committed by GitHub)
parent 412a1a3878
commit d4a62ccfb5

@@ -118,6 +118,8 @@ class MultiheadAttention(nn.Module):
         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
+        else:
+            attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
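The added else branch guarantees that attn_mask is always a float tensor of shape (1, tgt_len, src_len), so a key_padding_mask can then be folded in additively instead of being applied as a separate boolean mask. The snippet below is a minimal sketch of that additive-mask pattern, not the surrounding code of this file; the shapes and names (bsz, tgt_len, src_len, scores) are assumptions chosen for illustration.

import torch

# Assumed example dimensions, not taken from the repository.
bsz, tgt_len, src_len = 2, 4, 6

attn_mask = None                                        # optional (tgt_len, src_len) float mask
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True                         # mark the last two keys as padding

if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(0)                  # -> (1, tgt_len, src_len)
else:
    # Same initialization as the added lines in the hunk (device argument omitted here).
    attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32)

if key_padding_mask is not None:
    # Achieve the same result with an additive mask: put -inf at padded key
    # positions so softmax assigns them ~0 attention weight.
    pad = torch.zeros(bsz, 1, src_len).masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
    attn_mask = attn_mask + pad                          # broadcasts to (bsz, tgt_len, src_len)

scores = torch.randn(bsz, tgt_len, src_len)              # stand-in for q @ k^T / sqrt(d)
attn = torch.softmax(scores + attn_mask, dim=-1)         # padded keys receive zero weight

Keeping the mask additive means the padding and attention masks can be merged with a single broadcasted addition before the softmax, rather than handling a boolean mask through a separate code path.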