Update multihead_attention.py

This commit is contained in:
Mike Ranzinger 2023-04-23 18:08:47 -07:00 committed by GitHub
parent 37b64d41ce
commit a5a94191a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module):
if key_padding_mask is not None: if key_padding_mask is not None:
# Achieve same result with an additive mask # Achieve same result with an additive mask
attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf") attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
if rel_pos is not None: if rel_pos is not None:
attn_mask += rel_pos.view(attn_mask.size()) attn_mask = attn_mask + rel_pos.view(attn_mask.size())
if hasattr(F, "scaled_dot_product_attention"): if hasattr(F, "scaled_dot_product_attention"):
attn = F.scaled_dot_product_attention( attn = F.scaled_dot_product_attention(