Merge dd69dcb5e9
into 258eda3308
commit 5d16e572d5
@@ -339,22 +339,12 @@ class Encoder(nn.Module):
     ):
         assert src_tokens is not None or token_embeddings is not None

-        if encoder_padding_mask is None:
-            if src_tokens is not None:
-                encoder_padding_mask = torch.zeros_like(
-                    src_tokens, device=src_tokens.device
-                ).bool()
-            else:
-                encoder_padding_mask = torch.zeros(
-                    [token_embeddings.size(0), token_embeddings.size(1)],
-                    device=token_embeddings.device,
-                ).bool()
-
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))

         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        if encoder_padding_mask is not None:
+            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

         encoder_states = []
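The deleted block only built an all-False padding mask so that the later x * (1 - mask) multiplication became a no-op; the new code leaves encoder_padding_mask as None and guards the multiplication instead. A standalone sketch of that equivalence, with made-up shapes, not taken from the commit:

import torch

# With no real padding, the old path multiplied x by (1 - all-zeros mask),
# which leaves x unchanged, so skipping the multiplication when the mask is
# None produces the same tensor.
bsz, seq_len, dim = 2, 5, 8
x = torch.randn(bsz, seq_len, dim)

old_mask = torch.zeros(bsz, seq_len).bool()          # what the deleted code constructed
x_old = x * (1 - old_mask.unsqueeze(-1).type_as(x))  # old behaviour
x_new = x                                            # new behaviour when the mask is None

assert torch.equal(x_old, x_new)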
@@ -2,6 +2,7 @@
 # Licensed under The MIT License [see LICENSE for details]

 import math
+from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module):

     def forward(
         self,
-        query,
-        key,
-        value,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         incremental_state=None,
-        key_padding_mask=None,
-        attn_mask=None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         rel_pos=None,
         is_first_step=False,
     ):
@@ -85,31 +86,26 @@ class MultiheadAttention(nn.Module):
         q = self.q_proj(query)
         k = self.k_proj(key)
         v = self.v_proj(value)
-        q *= self.scaling

         q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
-        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
-        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
-        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
+        q = q.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
+        k = k.reshape(bsz, self.num_heads, src_len, self.head_dim)
+        v = v.reshape(bsz, self.num_heads, src_len, self.head_dim)

         if incremental_state is not None:
             if "prev_key" in incremental_state:
                 prev_key = incremental_state["prev_key"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 prev_value = incremental_state["prev_value"].view(
-                    bsz * self.num_heads, -1, self.head_dim
+                    bsz, self.num_heads, -1, self.head_dim
                 )
                 k = torch.cat([prev_key, k], dim=1)
                 v = torch.cat([prev_value, v], dim=1)
-            incremental_state["prev_key"] = k.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
-            incremental_state["prev_value"] = v.view(
-                bsz, self.num_heads, -1, self.head_dim
-            )
+            incremental_state["prev_key"] = k
+            incremental_state["prev_value"] = v
             src_len = k.size(1)

         if self.xpos is not None:
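The reshape changes above keep each head's data unchanged; the new code simply stops fusing batch and heads into one leading dimension so later steps can work on a 4-D (bsz, num_heads, len, head_dim) layout. A standalone sketch of that equivalence, with toy sizes, not taken from the commit:

import torch

# The old fused (bsz * num_heads, len, head_dim) layout and the new
# (bsz, num_heads, len, head_dim) layout are one reshape apart, so the
# per-head contents are identical.
bsz, num_heads, tgt_len, head_dim = 2, 4, 3, 8
q = torch.randn(bsz, tgt_len, num_heads, head_dim).transpose(1, 2)

q_old = q.reshape(bsz * num_heads, tgt_len, head_dim)  # layout before this change
q_new = q.reshape(bsz, num_heads, tgt_len, head_dim)   # layout after this change

assert torch.equal(q_old.view(bsz, num_heads, tgt_len, head_dim), q_new)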
@@ -117,31 +113,51 @@ class MultiheadAttention(nn.Module):
                 offset = src_len - 1
             else:
                 offset = 0
+            k, q = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (k, q))
             k = self.xpos(k, offset=0, downscale=True)
             q = self.xpos(q, offset=offset, downscale=False)
+            k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))

-        attn_weights = torch.bmm(q, k.transpose(1, 2))
-        if attn_mask is not None:
-            attn_weights = torch.nan_to_num(attn_weights)
-            attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
+        if attn_mask is not None and attn_mask.ndim != 4:
+            # Add batch and heads
+            attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
+        # else:
+        #     attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)

         if key_padding_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                float("-inf"),
-            )
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            # Achieve same result with an additive mask
+            key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
+            # Add heads and dst_len
+            key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
+            if attn_mask is not None:
+                attn_mask = attn_mask + key_padding_mask
+            else:
+                attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)

         if rel_pos is not None:
-            rel_pos = rel_pos.view(attn_weights.size())
-            attn_weights = attn_weights + rel_pos
+            if attn_mask is not None:
+                attn_mask = attn_mask + rel_pos.view(attn_mask.size())
+            else:
+                attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)

+        if hasattr(F, "scaled_dot_product_attention"):
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask, self.dropout_module.p
+            )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
+            # Permute to B,T,H,E, and then flatten to B,T,D
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
+            attn_weights = None
+        else:
+            q *= self.scaling
+            q, k, v = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (q, k, v))
+            attn_weights = torch.bmm(q, k.transpose(1, 2))

         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
             attn_weights
         )
+        attn_weights = attn_weights.view(
+            bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
         attn_probs = self.dropout_module(attn_weights)

         attn = torch.bmm(attn_probs, v)
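The rewritten masking folds key_padding_mask (and rel_pos) into one additive attn_mask so it can be passed to F.scaled_dot_product_attention. A standalone sketch, with toy sizes and not taken from the commit, showing that an additive -inf mask reproduces the old masked_fill-before-softmax result:

import torch
import torch.nn.functional as F

# Adding -inf at padded key positions before softmax gives the same
# probabilities as masking the scores with masked_fill, which is what the
# old code did.
bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
scores = torch.randn(bsz, num_heads, tgt_len, src_len)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # mark the last key position as padding

# Old path: boolean mask + masked_fill on the attention scores.
old = scores.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
).softmax(dim=-1)

# New path: convert the boolean mask to an additive float mask and add it.
additive = torch.where(key_padding_mask, float("-inf"), 0.0).view(bsz, 1, 1, src_len)
new = (scores + additive).softmax(dim=-1)

assert torch.allclose(old, new)

F.scaled_dot_product_attention accepts an additive float mask of this form, which is why the new code builds attn_mask up front instead of calling masked_fill on the scores.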
@@ -151,8 +167,4 @@ class MultiheadAttention(nn.Module):
             attn = self.inner_attn_ln(attn)

         attn = self.out_proj(attn)
-        attn_weights = attn_weights.view(
-            bsz, self.num_heads, tgt_len, src_len
-        ).transpose(1, 0)
-
         return attn, attn_weights