# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math

import torch
import torch.nn.functional as F
from torch import nn

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from .multiway_network import MultiwayWrapper
from .xpos_relative_position import XPOS


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        args,
        embed_dim,
        num_heads,
        dropout=0.0,
        self_attention=False,
        encoder_decoder_attention=False,
        subln=False,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        # Exactly one attention mode must be selected.
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert self.self_attention ^ self.encoder_decoder_attention

        # Query/key/value and output projections.
        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.out_proj = MultiwayWrapper(
            args, nn.Linear(embed_dim, embed_dim, bias=True)
        )
        # Optional sub-LayerNorm on the attention output, applied before out_proj.
        self.inner_attn_ln = (
            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
            if subln and self.self_attention
            else None
        )
        self.dropout_module = torch.nn.Dropout(dropout)
        # xPos relative position embedding, only used for self-attention.
        self.xpos = (
            XPOS(self.head_dim, args.xpos_scale_base)
            if args.xpos_rel_pos and self.self_attention
            else None
        )

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query,
        key,
        value,
        incremental_state=None,
        key_padding_mask=None,
        attn_mask=None,
        rel_pos=None,
    ):
        # query: (batch, tgt_len, embed_dim); key/value: (batch, src_len, embed_dim)
        bsz, tgt_len, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"

        key_bsz, src_len, _ = key.size()
        assert key_bsz == bsz, f"{query.size(), key.size()}"
        assert value is not None
        assert (bsz, src_len) == value.shape[:2]

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        # Scale queries by 1/sqrt(head_dim) before the dot product.
        q *= self.scaling

        # Split into heads: (batch * num_heads, seq_len, head_dim).
        q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)

        if incremental_state is not None:
            # Incremental decoding: prepend the cached keys/values from earlier
            # steps, then refresh the cache with the extended tensors.
            if "prev_key" in incremental_state:
                prev_key = incremental_state["prev_key"].view(
                    bsz * self.num_heads, -1, self.head_dim
                )
                prev_value = incremental_state["prev_value"].view(
                    bsz * self.num_heads, -1, self.head_dim
                )
                k = torch.cat([prev_key, k], dim=1)
                v = torch.cat([prev_value, v], dim=1)
            incremental_state["prev_key"] = k.view(
                bsz, self.num_heads, -1, self.head_dim
            )
            incremental_state["prev_value"] = v.view(
                bsz, self.num_heads, -1, self.head_dim
            )
            src_len = k.size(1)

        if self.xpos is not None:
            # During incremental decoding the current query sits at position
            # src_len - 1; keys are always rotated starting from position 0.
            if incremental_state is not None:
                offset = src_len - 1
            else:
                offset = 0
            k = self.xpos(k, offset=0, downscale=True)
            q = self.xpos(q, offset=offset, downscale=False)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            # Clear any NaNs before applying the additive attention mask.
            attn_weights = torch.nan_to_num(attn_weights)
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # Mask out padded source positions across all heads.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if rel_pos is not None:
            rel_pos = rel_pos.view(attn_weights.size())
            attn_weights = attn_weights + rel_pos

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
            attn_weights
        )
        attn_probs = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_probs, v)
        # Merge heads back into (batch, tgt_len, embed_dim).
        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

        if self.inner_attn_ln is not None:
            attn = self.inner_attn_ln(attn)

        attn = self.out_proj(attn)
        # Return per-head weights as (num_heads, batch, tgt_len, src_len).
        attn_weights = attn_weights.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)

        return attn, attn_weights
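

# ---------------------------------------------------------------------------
# Quick smoke-test sketch, not part of the module's public surface.  It
# assumes `args` only needs the fields referenced above plus a `multiway`
# flag consumed by MultiwayWrapper; adjust the namespace to match the config
# object the surrounding codebase actually passes in.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        multiway=False,       # assumed MultiwayWrapper switch
        layernorm_eps=1e-5,
        xpos_rel_pos=False,   # keep xPos disabled for the sketch
        xpos_scale_base=512,
    )
    attn = MultiheadAttention(
        args, embed_dim=512, num_heads=8, dropout=0.1, self_attention=True, subln=True
    )
    x = torch.randn(2, 16, 512)  # (batch, tgt_len, embed_dim)
    out, weights = attn(x, x, x)
    print(out.shape)      # torch.Size([2, 16, 512])
    print(weights.shape)  # torch.Size([8, 2, 16, 16]): (heads, batch, tgt, src)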