diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 191b424..7895f8a 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -121,10 +121,10 @@ class MultiheadAttention(nn.Module): if key_padding_mask is not None: # Achieve same result with an additive mask - attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf") + attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf") if rel_pos is not None: - attn_mask += rel_pos.view(attn_mask.size()) + attn_mask = attn_make + rel_pos.view(attn_mask.size()) if hasattr(F, "scaled_dot_product_attention"): attn = F.scaled_dot_product_attention(