diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 1d736bf..9782024 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -118,6 +118,8 @@ class MultiheadAttention(nn.Module):
 
         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
+        else:
+            attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
 
         if key_padding_mask is not None:
             # Achieve same result with an additive mask
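For context, a minimal standalone sketch (not torchscale code; the `tgt_len`/`src_len` shapes simply mirror the diff) of why an all-zeros additive mask is a safe default: adding zeros to the attention scores before the softmax leaves the weights unchanged, so downstream code can treat `attn_mask` as always being a tensor instead of special-casing `None`.

```python
import torch

# Hypothetical standalone check, assuming an additive attention mask of
# shape (1, tgt_len, src_len) as created in the diff above.
tgt_len, src_len = 4, 4
scores = torch.randn(1, tgt_len, src_len)              # raw attention scores

attn_mask = torch.zeros(1, tgt_len, src_len)            # zero mask, as in the new else branch
with_mask = torch.softmax(scores + attn_mask, dim=-1)   # additive mask applied before softmax
without_mask = torch.softmax(scores, dim=-1)            # no mask at all

assert torch.allclose(with_mask, without_mask)          # an all-zeros additive mask is a no-op
```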