Update multihead_attention.py
parent 37b64d41ce
commit a5a94191a1
@@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            attn_mask = attn_mask + rel_pos.view(attn_mask.size())
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
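Why replacing += with out-of-place addition matters (the commit message does not say, so treat this as an inferred motivation rather than the author's stated one): attn_mask += x writes into the existing tensor, which silently mutates the caller's mask and raises a RuntimeError outright when attn_mask is a broadcast (expanded) view, whereas attn_mask = attn_mask + x allocates a fresh tensor. A minimal, self-contained sketch of the failure mode; the shapes and padding position are illustrative only:

import torch

# Illustrative sizes, not taken from the repository.
bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5

# A single (tgt_len, src_len) mask expanded across batch and heads is a
# common way attn_mask arrives: it is a stride-0 view, not writable memory.
attn_mask = torch.zeros(tgt_len, src_len).expand(bsz, num_heads, tgt_len, src_len)

# Additive padding bias: -inf at masked key positions, 0 elsewhere
# (here we pretend the last key of every sequence is padding).
bias = torch.zeros(bsz, 1, 1, src_len)
bias[:, :, :, -1] = float("-inf")

try:
    attn_mask += bias  # in-place write into an expanded view
except RuntimeError as err:
    print("in-place add failed:", err)

attn_mask = attn_mask + bias  # out-of-place: allocates a new tensor
print(attn_mask.shape)  # torch.Size([2, 4, 5, 5])

The same reasoning covers the rel_pos hunk: both sites now build a new mask rather than editing the incoming one in place.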