Masks are now optional, and not created. Fixes required to support FlashAttention (e.g. no mask, fp/bf16)

This commit is contained in:
Mike Ranzinger 2023-05-09 19:21:25 +00:00
parent 62cedb9c8f
commit 29c6eadb83
2 changed files with 23 additions and 23 deletions

View File

@ -339,23 +339,13 @@ class Encoder(nn.Module):
):
assert src_tokens is not None or token_embeddings is not None
if encoder_padding_mask is None:
if src_tokens is not None:
encoder_padding_mask = torch.zeros_like(
src_tokens, device=src_tokens.device
).bool()
else:
encoder_padding_mask = torch.zeros(
[token_embeddings.size(0), token_embeddings.size(1)],
device=token_embeddings.device,
).bool()
if multiway_split_position is not None:
assert self.args.multiway
self.apply(set_split_position(multiway_split_position))
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
if encoder_padding_mask is not None:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
encoder_states = []

View File

@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details]
import math
from typing import Optional
import torch
import torch.nn.functional as F
@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module):
def forward(
self,
query,
key,
value,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
incremental_state=None,
key_padding_mask=None,
attn_mask=None,
key_padding_mask: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
rel_pos=None,
):
bsz, tgt_len, embed_dim = query.size()
@ -116,18 +117,27 @@ class MultiheadAttention(nn.Module):
q = self.xpos(q, offset=offset, downscale=False)
k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
if attn_mask is not None and attn_mask.ndim != 4:
# Add batch and heads
attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
# else:
# attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)
if key_padding_mask is not None:
# Achieve same result with an additive mask
key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
attn_mask = attn_mask + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
# Add heads and dst_len
key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
if attn_mask is not None:
attn_mask = attn_mask + key_padding_mask
else:
attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)
if rel_pos is not None:
attn_mask = attn_make + rel_pos.view(attn_mask.size())
if attn_mask is not None:
attn_mask = attn_mask + rel_pos.view(attn_mask.size())
else:
attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)
if hasattr(F, "scaled_dot_product_attention"):
attn = F.scaled_dot_product_attention(