From 412a1a3878567e608d544ccc6c0c0a7dce128e17 Mon Sep 17 00:00:00 2001 From: Mike Ranzinger Date: Sun, 23 Apr 2023 18:17:41 -0700 Subject: [PATCH] Update multihead_attention.py --- torchscale/component/multihead_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 7895f8a..1d736bf 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -121,7 +121,8 @@ class MultiheadAttention(nn.Module): if key_padding_mask is not None: # Achieve same result with an additive mask - attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf") + key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0) + attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) if rel_pos is not None: attn_mask = attn_make + rel_pos.view(attn_mask.size())