diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 191b424..7895f8a 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
 
         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            attn_mask = attn_mask + rel_pos.view(attn_mask.size())
 
         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
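
Presumably the motivation for swapping `+=` for the out-of-place form is that in-place addition cannot broadcast the left-hand operand: when `attn_mask` arrives as a 2-D mask and the key-padding or relative-position term carries a batch dimension, `+=` raises a shape error, while `a = a + b` broadcasts both operands and simply rebinds the name. A minimal sketch below illustrates this; the shapes and the `padding_term` tensor are illustrative assumptions, not values taken from the module.

```python
import torch

# Illustrative shapes only (assumed for this sketch): a 2-D attention mask
# plus an additive key-padding term that carries a batch dimension.
tgt_len, src_len, bsz = 5, 5, 2
attn_mask = torch.zeros(tgt_len, src_len)       # [tgt_len, src_len]
padding_term = torch.zeros(bsz, 1, 1, src_len)  # [bsz, 1, 1, src_len]
padding_term[0, ..., -1] = float("-inf")        # mask the last key of sample 0

# In-place addition cannot grow the left operand to the broadcast shape:
#   attn_mask += padding_term
#   -> RuntimeError: output with shape [5, 5] doesn't match broadcast shape [2, 1, 5, 5]
# The out-of-place form broadcasts both operands and rebinds the name instead:
attn_mask = attn_mask + padding_term
print(attn_mask.shape)  # torch.Size([2, 1, 5, 5])
```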