Update multihead_attention.py
commit a5a94191a1
parent 37b64d41ce
@@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            attn_mask = attn_mask + rel_pos.view(attn_mask.size())
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
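The commit replaces in-place "+=" with out-of-place addition when building the additive mask. A likely motivation, sketched below with illustrative shapes not taken from the commit: if attn_mask arrives as a broadcast view (e.g. produced by expand()), in-place ops raise a RuntimeError because multiple output elements alias one memory location; out-of-place addition also leaves the caller's original mask tensor unmodified.

import torch

# Illustrative repro: an additive mask built by broadcasting shares
# storage across the expanded dimensions, so in-place ops are rejected.
attn_mask = torch.zeros(1, 1, 4, 4).expand(2, 8, 4, 4)
bias = torch.randn(2, 8, 4, 4)

try:
    attn_mask += bias  # RuntimeError: writes to overlapping memory
except RuntimeError as err:
    print("in-place add failed:", err)

# Out-of-place addition, as in the commit: allocates a fresh tensor
# and never mutates the tensor the caller passed in.
attn_mask = attn_mask + bias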
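The unchanged context lines show the hasattr guard that selects the fused kernel when it exists (PyTorch >= 2.0). A hedged sketch of that fast path with a manual equivalent; the shapes and the fallback branch are illustrative assumptions, not code from this file:

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 4, 64)
k = torch.randn(2, 8, 4, 64)
v = torch.randn(2, 8, 4, 64)
attn_mask = torch.zeros(2, 8, 4, 4)  # additive float mask; -inf marks padding

if hasattr(F, "scaled_dot_product_attention"):
    # Fused kernel; accepts the additive float mask built above.
    attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
else:
    # Manual equivalent for older PyTorch: softmax(QK^T / sqrt(d) + mask) V.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    attn = torch.softmax(scores + attn_mask, dim=-1) @ v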