From d4a62ccfb512bad8b9ede1eb405e3e6d75934370 Mon Sep 17 00:00:00 2001
From: Mike Ranzinger
Date: Sun, 23 Apr 2023 18:28:08 -0700
Subject: [PATCH] Update multihead_attention.py

---
 torchscale/component/multihead_attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 1d736bf..9782024 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -118,6 +118,8 @@ class MultiheadAttention(nn.Module):
 
         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
+        else:
+            attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
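
For context, a minimal standalone sketch of what the added branch does: when the
caller supplies no attention mask, an all-zeros additive mask is built so that
attn_mask is always a tensor from that point on. The names k, tgt_len, src_len,
and the torch.zeros call are taken from the diff; the shapes and the surrounding
bmm/add lines are assumptions about the rest of the forward pass, not code from
the patched file.

    import torch

    # Assumed shapes: q is (batch * heads, tgt_len, head_dim),
    # k is (batch * heads, src_len, head_dim).
    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 5, 8)
    tgt_len, src_len = q.shape[1], k.shape[1]

    attn_mask = None  # caller supplied no mask

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
    else:
        # The patched branch: an additive mask of zeros is a numerical no-op,
        # but guarantees attn_mask is a tensor from here on.
        attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32,
                                device=k.device)

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    attn_weights += attn_mask  # (1, tgt_len, src_len) broadcasts over the batch dim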