From a5a94191a15e0a5dbab577ada8c1bccd410ac44f Mon Sep 17 00:00:00 2001
From: Mike Ranzinger
Date: Sun, 23 Apr 2023 18:08:47 -0700
Subject: [PATCH] Update multihead_attention.py

---
 torchscale/component/multihead_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 191b424..7895f8a 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            attn_mask = attn_mask + rel_pos.view(attn_mask.size())
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
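
Editor's note: the patch does not state its motivation, but the likely reason for swapping the in-place += for out-of-place addition is that PyTorch forbids in-place writes to expanded (broadcast) tensor views and to leaf tensors that require grad; attn_mask = attn_mask + ... allocates a fresh tensor instead of mutating the incoming mask. Below is a minimal sketch of that failure mode, assuming attn_mask can arrive as a broadcast view; the shapes are illustrative and not taken from torchscale.

import torch

# Hypothetical mask built as an expanded (broadcast) view: many output
# elements alias the same underlying memory, so in-place writes are illegal.
attn_mask = torch.zeros(1, 1, 4, 4).expand(2, 8, 4, 4)
bias = torch.randn(2, 8, 4, 4)

try:
    attn_mask += bias  # in-place add: writes through the expanded view
except RuntimeError as err:
    print("in-place add failed:", err)

attn_mask = attn_mask + bias  # out-of-place: allocates a new contiguous tensor
print(attn_mask.shape)        # torch.Size([2, 8, 4, 4])

The out-of-place form is also side-effect free: the caller's key_padding_mask and any tensor attn_mask previously aliased are left untouched.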