Update multihead_attention.py
This commit is contained in:
parent
a5a94191a1
commit
412a1a3878
|
@ -121,7 +121,8 @@ class MultiheadAttention(nn.Module):
|
|||
|
||||
if key_padding_mask is not None:
|
||||
# Achieve same result with an additive mask
|
||||
attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
|
||||
key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
|
||||
attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
|
||||
|
||||
if rel_pos is not None:
|
||||
attn_mask = attn_mask + rel_pos.view(attn_mask.size())
|
||||
|
|
Loading…
Reference in New Issue
Block a user