incorporated fast attention (F.scaled_dot_product_attention) into MultiheadAttention
This commit is contained in:
parent
4ae3b248ee
commit
37b64d41ce
@@ -84,31 +84,26 @@ class MultiheadAttention(nn.Module):
         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
-        q *= self.scaling

         q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
-        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
-        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
-        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        q = q.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz, self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz, self.num_heads, src_len, self.head_dim)

         if incremental_state is not None:
             if "prev_key" in incremental_state:
                 prev_key = incremental_state["prev_key"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 prev_value = incremental_state["prev_value"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 k = torch.cat([prev_key, k], dim=1)
                 v = torch.cat([prev_value, v], dim=1)
-            incremental_state["prev_key"] = k.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
-            incremental_state["prev_value"] = v.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
+            incremental_state["prev_key"] = k
+            incremental_state["prev_value"] = v
             src_len = k.size(1)

         if self.xpos is not None:
@@ -116,42 +111,43 @@
                 offset = src_len - 1
             else:
                 offset = 0
+            k, q = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (k, q))
             k = self.xpos(k, offset=0, downscale=True)
             q = self.xpos(q, offset=offset, downscale=False)

-        attn_weights = torch.bmm(q, k.transpose(1, 2))
+            k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))

         if attn_mask is not None:
-            attn_weights = torch.nan_to_num(attn_weights)
             attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask

         if key_padding_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                float("-inf"),
-            )
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            # Achieve same result with an additive mask
+            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")

         if rel_pos is not None:
-            rel_pos = rel_pos.view(attn_weights.size())
-            attn_weights = attn_weights + rel_pos
+            attn_mask += rel_pos.view(attn_mask.size())

-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
-            attn_weights
-        )
-        attn_probs = self.dropout_module(attn_weights)
+        if hasattr(F, "scaled_dot_product_attention"):
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask, self.dropout_module.p
+            )
+            attn_weights = None
+        else:
+            q *= self.scaling
+            q, k, v = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (q, k, v))
+            attn_weights = torch.bmm(q, k.transpose(1, 2))

-        attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
+                attn_weights
+            )
+            attn_weights = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            attn_probs = self.dropout_module(attn_weights)

+            attn = torch.bmm(attn_probs, v)
+            attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

         if self.inner_attn_ln is not None:
             attn = self.inner_attn_ln(attn)

         attn = self.out_proj(attn)
-        attn_weights = attn_weights.view(
-            bsz, self.num_heads, tgt_len, src_len
-        ).transpose(1, 0)

         return attn, attn_weights
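For reference, a minimal standalone sketch (assuming PyTorch >= 2.0; shapes and seed are made up) of the equivalence the new fast path relies on: F.scaled_dot_product_attention applies the 1/sqrt(head_dim) scaling internally, which is why q *= self.scaling now only appears in the manual bmm/softmax fallback branch.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8  # made-up sizes
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

# Fused path: q/k/v stay 4D and the 1/sqrt(head_dim) scaling happens inside the kernel.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)

# Manual fallback path: scale q explicitly, fold heads into the batch, then bmm/softmax/bmm.
scaling = head_dim ** -0.5
q3 = (q * scaling).reshape(bsz * num_heads, tgt_len, head_dim)
k3 = k.reshape(bsz * num_heads, src_len, head_dim)
v3 = v.reshape(bsz * num_heads, src_len, head_dim)
attn_weights = torch.softmax(torch.bmm(q3, k3.transpose(1, 2)), dim=-1)
manual = torch.bmm(attn_weights, v3).view(bsz, num_heads, tgt_len, head_dim)

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True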
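The boolean masked_fill on attn_weights is replaced by an additive mask handed straight to the fused kernel. Below is a sketch, with hypothetical shapes rather than the module's own variables, of building an additive key-padding mask that broadcasts over heads and query positions; constructing it with masked_fill on a zero tensor sidesteps the corner case that multiplying a 0/1 float mask by float("-inf") produces NaN at the unmasked positions (0 * inf is NaN in IEEE arithmetic).

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8  # made-up sizes
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

# key_padding_mask: True marks padding positions that should be ignored.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # pretend the last two keys are padding

# Additive mask of shape (bsz, 1, 1, src_len): 0 where attention is allowed,
# -inf where it is not. It broadcasts against attention scores of shape
# (bsz, num_heads, tgt_len, src_len).
additive_mask = torch.zeros(bsz, 1, 1, src_len).masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)

attn = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask, dropout_p=0.0)
print(attn.shape)  # expected: torch.Size([2, 4, 5, 8])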
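With q/k/v kept in (bsz, num_heads, seq_len, head_dim) layout, the incremental state can cache previous keys and values directly in 4D rather than in the flattened (bsz * num_heads, ...) form. A sketch of one possible per-step cache update during decoding follows; the cache dict and step helper are hypothetical, not the code above, and in this layout the sequence axis is dim=2.

import torch

bsz, num_heads, head_dim = 2, 4, 8  # made-up sizes
cache = {}  # stands in for incremental_state

def step(cache, k_new, v_new):
    """Append one step of keys/values to a (bsz, num_heads, seq, head_dim) cache."""
    if "prev_key" in cache:
        # The sequence length lives on dim=2 in this layout.
        k = torch.cat([cache["prev_key"], k_new], dim=2)
        v = torch.cat([cache["prev_value"], v_new], dim=2)
    else:
        k, v = k_new, v_new
    cache["prev_key"], cache["prev_value"] = k, v
    return k, v

for t in range(3):
    k_t = torch.randn(bsz, num_heads, 1, head_dim)
    v_t = torch.randn(bsz, num_heads, 1, head_dim)
    k, v = step(cache, k_t, v_t)

print(k.shape)  # expected: torch.Size([2, 4, 3, 8])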