Merge pull request #1 from mranzinger/efficient
Oh I must have overlooked that
commit dd69dcb5e9
@@ -339,23 +339,13 @@ class Encoder(nn.Module):
     ):
         assert src_tokens is not None or token_embeddings is not None

-        if encoder_padding_mask is None:
-            if src_tokens is not None:
-                encoder_padding_mask = torch.zeros_like(
-                    src_tokens, device=src_tokens.device
-                ).bool()
-            else:
-                encoder_padding_mask = torch.zeros(
-                    [token_embeddings.size(0), token_embeddings.size(1)],
-                    device=token_embeddings.device,
-                ).bool()
-
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))

         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        if encoder_padding_mask is not None:
+            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

         encoder_states = []
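The block removed above only existed to synthesize an all-False padding mask so that the later x * (1 - mask) multiply would be a no-op; the rewritten code skips both the allocation and the multiply whenever no mask is supplied. A minimal standalone sketch of that equivalence (made-up tensor sizes, not the Encoder itself):

import torch

# Standalone sketch, not the Encoder itself; sizes are made up.
x = torch.randn(2, 5, 8)  # (batch, seq_len, embed_dim)

# Old path with no real padding: build an all-False mask, then multiply.
pad = torch.zeros(x.size(0), x.size(1), device=x.device).bool()
x_old = x * (1 - pad.unsqueeze(-1).type_as(x))

# New path: encoder_padding_mask is None, so x passes through untouched.
x_new = x

assert torch.equal(x_old, x_new)  # multiplying by an all-ones factor is a no-op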
@@ -2,6 +2,7 @@
 # Licensed under The MIT License [see LICENSE for details]

 import math
+from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -64,12 +65,12 @@ class MultiheadAttention(nn.Module):

     def forward(
         self,
-        query,
-        key,
-        value,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         incremental_state=None,
-        key_padding_mask=None,
-        attn_mask=None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         rel_pos=None,
     ):
         bsz, tgt_len, embed_dim = query.size()
@@ -116,20 +117,35 @@ class MultiheadAttention(nn.Module):
             q = self.xpos(q, offset=offset, downscale=False)
         k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))

-        if attn_mask is not None:
-            attn_mask = attn_mask.unsqueeze(0)
+        if attn_mask is not None and attn_mask.ndim != 4:
+            # Add batch and heads
+            attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
+        # else:
+        #     attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)

         if key_padding_mask is not None:
             # Achieve same result with an additive mask
-            attn_mask += key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.float32) * float("-inf")
+            key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
+            # Add heads and dst_len
+            key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
+            if attn_mask is not None:
+                attn_mask = attn_mask + key_padding_mask
+            else:
+                attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)

         if rel_pos is not None:
-            attn_mask += rel_pos.view(attn_mask.size())
+            if attn_mask is not None:
+                attn_mask = attn_mask + rel_pos.view(attn_mask.size())
+            else:
+                attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)

         if hasattr(F, "scaled_dot_product_attention"):
             attn = F.scaled_dot_product_attention(
                 q, k, v, attn_mask, self.dropout_module.p
             )
+            # attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
+            # Permute to B,T,H,E, and then flatten to B,T,D
+            attn = attn.permute(0, 2, 1, 3).flatten(2)
             attn_weights = None
         else:
             q *= self.scaling
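Once the mask is a single 4-D additive tensor, the whole attention step can go through one fused F.scaled_dot_product_attention call, and the (B, H, T, E) output only needs a permute and flatten to return to (B, T, D). Continuing the sketch above with the same made-up sizes (head_dim is likewise hypothetical):

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)
attn_mask = torch.zeros(bsz, num_heads, tgt_len, src_len)  # additive mask from the sketch above

# One fused kernel: softmax(q @ k^T / sqrt(E) + attn_mask) @ v, with optional dropout.
attn = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=0.0)
print(attn.shape)  # torch.Size([2, 4, 5, 8])  -> (B, H, T, E)

# Permute to (B, T, H, E), then flatten the heads into the embedding dim: (B, T, H*E).
out = attn.permute(0, 2, 1, 3).flatten(2)
print(out.shape)  # torch.Size([2, 5, 32])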