support lm prefix computation in one go

Xingxing Zhang, 2023-06-03 15:37:47 +00:00
parent b59b6f87b9, commit c766630327
2 changed files with 18 additions and 5 deletions
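Context for the change: with a populated incremental_state, the decoder used to truncate its input to the last token, so a prompt/prefix had to be fed one token at a time before generation could start. The new "is_first_step" entry lets the first call run the whole prefix through a single forward pass (building the full causal mask and filling the per-layer key/value caches), after which decoding continues token by token. Below is a minimal calling sketch, not part of this commit: the helper name generate, the greedy loop, the (logits, extra) return convention, and the caller resetting the flag after the first pass are all illustrative assumptions.

import torch

@torch.no_grad()
def generate(decoder, prefix_tokens, max_new_tokens):
    # First call: flag the state so the whole prefix is processed in one go
    # while the per-layer key/value caches are filled. (Sketch, not the
    # repository's generation API.)
    incremental_state = {"is_first_step": True}
    logits, _ = decoder(prefix_tokens, incremental_state=incremental_state)

    tokens = prefix_tokens
    for _ in range(max_new_tokens):
        # Greedy pick of the next token from the last position's logits.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        # Subsequent calls: the flag is off, so the decoder keeps only the
        # last position and reads everything else from the cache.
        incremental_state["is_first_step"] = False
        logits, _ = decoder(tokens, incremental_state=incremental_state)
    return tokens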


@@ -140,6 +140,7 @@ class DecoderLayer(nn.Module):
         self_attn_padding_mask=None,
         self_attn_rel_pos=None,
         cross_attn_rel_pos=None,
+        is_first_step=False,
     ):
         residual = x
         if self.normalize_before:
@@ -153,6 +154,7 @@ class DecoderLayer(nn.Module):
             incremental_state=incremental_state,
             attn_mask=self_attn_mask,
             rel_pos=self_attn_rel_pos,
+            is_first_step=is_first_step,
         )
         x = self.dropout_module(x)
@@ -357,7 +359,7 @@ class Decoder(nn.Module):
                 tokens, incremental_state=incremental_state
             )

-        if incremental_state is not None:
+        if incremental_state is not None and not self.is_first_step(incremental_state):
             tokens = tokens[:, -1:]
             if positions is not None:
                 positions = positions[:, -1:]
@@ -377,6 +379,11 @@

         return x, embed

+    def is_first_step(self, incremental_state):
+        if incremental_state is None:
+            return False
+        return incremental_state.get("is_first_step", False)
+
     def forward(
         self,
         prev_output_tokens,
@@ -392,6 +399,7 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
+        is_first_step = self.is_first_step(incremental_state)

         # relative position
         self_attn_rel_pos_bias = None
@@ -400,7 +408,7 @@
             self_attn_rel_pos_bias = self.self_attn_relative_position(
                 batch_size=x.size(0), qlen=slen, klen=slen
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
@@ -409,7 +417,7 @@
                 qlen=slen,
                 klen=encoder_out["encoder_out"].size(1),
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]

         # decoder layers
@@ -421,7 +429,7 @@
         l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []

         for idx, layer in enumerate(self.layers):
-            if incremental_state is None:
+            if incremental_state is None or is_first_step:
                 self_attn_mask = torch.triu(
                     torch.zeros([x.size(1), x.size(1)])
                     .float()
@@ -429,6 +437,9 @@
                     .type_as(x),
                     1,
                 )
+                if is_first_step and incremental_state is not None:
+                    if idx not in incremental_state:
+                        incremental_state[idx] = {}
             else:
                 self_attn_mask = None
                 if idx not in incremental_state:
@@ -445,6 +456,7 @@
                 self_attn_padding_mask=self_attn_padding_mask,
                 self_attn_rel_pos=self_attn_rel_pos_bias,
                 cross_attn_rel_pos=cross_attn_rel_pos_bias,
+                is_first_step=is_first_step,
             )
             l_aux.append(l_aux_i)
             inner_states.append(x)
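In the decoder changes above, the first-step path reuses the training-style mask: a full (qlen x qlen) causal mask is built for the whole prefix and an empty cache slot is created per layer, while later incremental steps skip the mask and only process the last position. A small standalone restatement of that mask branch, assuming only torch; build_self_attn_mask is an illustrative name, not a function in the repository, and x stands in for the embedded prefix batch.

import torch

def build_self_attn_mask(x, incremental_state, is_first_step):
    # Same branch as in Decoder.forward: a full causal mask when there is
    # no cache or when the whole prefix is run in one go; no mask when a
    # single cached step is decoded.
    if incremental_state is None or is_first_step:
        return torch.triu(
            torch.zeros([x.size(1), x.size(1)])
            .float()
            .fill_(float("-inf"))
            .type_as(x),
            1,
        )
    return None

x = torch.randn(2, 5, 16)                                             # batch of 2 prefixes, length 5
print(build_self_attn_mask(x, {"is_first_step": True}, True).shape)   # torch.Size([5, 5])
print(build_self_attn_mask(x, {"is_first_step": False}, False))       # None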


@@ -71,6 +71,7 @@ class MultiheadAttention(nn.Module):
         key_padding_mask=None,
         attn_mask=None,
         rel_pos=None,
+        is_first_step=False,
     ):
         bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
@@ -112,7 +113,7 @@
             src_len = k.size(1)

         if self.xpos is not None:
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 offset = src_len - 1
             else:
                 offset = 0
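The attention-side change only adjusts the xpos (rotary-style relative position) offset: when a single new token attends over src_len cached keys, its absolute position is src_len - 1, but when the whole prefix is passed in one go (or there is no cache at all) positions start at 0. A tiny restatement of that rule; xpos_query_offset is an illustrative helper, not a function in this file.

def xpos_query_offset(src_len, incremental_state, is_first_step):
    # Same condition as in MultiheadAttention.forward: shift the rotation
    # only when exactly one new query position is decoded on top of
    # src_len cached key positions.
    if incremental_state is not None and not is_first_step:
        return src_len - 1
    return 0

assert xpos_query_offset(7, {}, False) == 6    # one new token over 7 cached keys
assert xpos_query_offset(7, {}, True) == 0     # whole prefix processed at once
assert xpos_query_offset(7, None, False) == 0  # training / no cache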