Update multihead_attention.py
This commit is contained in:
parent
412a1a3878
commit
d4a62ccfb5
|
@ -118,6 +118,8 @@ class MultiheadAttention(nn.Module):
|
|||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
else:
|
||||
attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# Achieve same result with an additive mask
|
||||
|
|
Loading…
Reference in New Issue
Block a user