support lm prefix computation in one go

This commit is contained in:
parent b59b6f87b9
commit c766630327
@@ -140,6 +140,7 @@ class DecoderLayer(nn.Module):
         self_attn_padding_mask=None,
         self_attn_rel_pos=None,
         cross_attn_rel_pos=None,
+        is_first_step=False,
     ):
         residual = x
         if self.normalize_before:
@@ -153,6 +154,7 @@ class DecoderLayer(nn.Module):
             incremental_state=incremental_state,
             attn_mask=self_attn_mask,
             rel_pos=self_attn_rel_pos,
+            is_first_step=is_first_step,
         )
         x = self.dropout_module(x)

@@ -357,7 +359,7 @@ class Decoder(nn.Module):
                 tokens, incremental_state=incremental_state
             )

-        if incremental_state is not None:
+        if incremental_state is not None and not self.is_first_step(incremental_state):
             tokens = tokens[:, -1:]
             if positions is not None:
                 positions = positions[:, -1:]
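The slicing change above is the core of the commit: with an incremental_state present, forward_embedding normally keeps only the newest token, but on the first step the whole language-model prefix has to pass through unchanged so that every prefix position is cached in one forward pass. A shape-only illustration with a hypothetical 4-token prefix (the tensors below are not part of the diff):

import torch

tokens = torch.tensor([[5, 8, 2, 9]])   # hypothetical prefix: batch size 1, length 4

# Ordinary incremental step: only the newest position is embedded.
step_input = tokens[:, -1:]              # shape [1, 1]

# First step with is_first_step == True: the slice is skipped and the full
# [1, 4] prefix flows through embedding and every decoder layer, filling the
# key/value caches for all prefix positions at once.
prefix_input = tokens                    # shape [1, 4]

print(step_input.shape, prefix_input.shape)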
@@ -377,6 +379,11 @@ class Decoder(nn.Module):

         return x, embed

+    def is_first_step(self, incremental_state):
+        if incremental_state is None:
+            return False
+        return incremental_state.get("is_first_step", False)
+
     def forward(
         self,
         prev_output_tokens,
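The new helper reads the flag straight from the incremental_state dict, so a caller opts in by seeding that dict before the first call. A hypothetical caller-side sketch (not part of this commit; the decoder handle, the prompt tensor, and the assumption that decoder(...) returns a (logits, extras) pair are all placeholders):

import torch

def generate_with_prefix(decoder, prompt_tokens, max_new_tokens=8):
    # Seed the cache dict with the flag that Decoder.is_first_step reads via
    # incremental_state.get("is_first_step", False).
    incremental_state = {"is_first_step": True}

    tokens = prompt_tokens
    for _ in range(max_new_tokens):
        # First iteration: the whole prompt is processed in one go and the
        # per-layer caches inside incremental_state are populated.
        # Later iterations: only the newest token is passed in.
        out, _ = decoder(tokens, incremental_state=incremental_state)
        next_token = out[:, -1, :].argmax(dim=-1, keepdim=True)
        prompt_tokens = torch.cat([prompt_tokens, next_token], dim=1)
        incremental_state["is_first_step"] = False   # assumed: caller clears the flag after step one
        tokens = next_token
    return prompt_tokens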
@@ -392,6 +399,7 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
+        is_first_step = self.is_first_step(incremental_state)

         # relative position
         self_attn_rel_pos_bias = None
@@ -400,7 +408,7 @@ class Decoder(nn.Module):
             self_attn_rel_pos_bias = self.self_attn_relative_position(
                 batch_size=x.size(0), qlen=slen, klen=slen
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
@@ -409,7 +417,7 @@ class Decoder(nn.Module):
                 qlen=slen,
                 klen=encoder_out["encoder_out"].size(1),
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]

         # decoder layers
@@ -421,7 +429,7 @@ class Decoder(nn.Module):
         l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []

         for idx, layer in enumerate(self.layers):
-            if incremental_state is None:
+            if incremental_state is None or is_first_step:
                 self_attn_mask = torch.triu(
                     torch.zeros([x.size(1), x.size(1)])
                     .float()
@@ -429,6 +437,9 @@ class Decoder(nn.Module):
                     .type_as(x),
                     1,
                 )
+                if is_first_step and incremental_state is not None:
+                    if idx not in incremental_state:
+                        incremental_state[idx] = {}
             else:
                 self_attn_mask = None
                 if idx not in incremental_state:
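Because the one-go prefix pass sees more than one query position, the layers still need the full causal mask that incremental decoding normally skips, and each layer's cache slot must exist before the layer writes into it. A small standalone sketch of the mask this triu construction produces for a toy length of 4 (the intermediate .fill_(float("-inf")) call sits in context lines the hunk does not show and is assumed here):

import torch

seq_len = 4   # toy prefix length
causal_mask = torch.triu(
    torch.zeros([seq_len, seq_len]).float().fill_(float("-inf")),
    1,
)
print(causal_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])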
@@ -445,6 +456,7 @@ class Decoder(nn.Module):
                 self_attn_padding_mask=self_attn_padding_mask,
                 self_attn_rel_pos=self_attn_rel_pos_bias,
                 cross_attn_rel_pos=cross_attn_rel_pos_bias,
+                is_first_step=is_first_step,
             )
             l_aux.append(l_aux_i)
             inner_states.append(x)
@@ -71,6 +71,7 @@ class MultiheadAttention(nn.Module):
         key_padding_mask=None,
         attn_mask=None,
         rel_pos=None,
+        is_first_step=False,
     ):
         bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
@@ -112,7 +113,7 @@ class MultiheadAttention(nn.Module):
             src_len = k.size(1)

         if self.xpos is not None:
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 offset = src_len - 1
             else:
                 offset = 0
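The xPos branch follows the same rule: during ordinary incremental decoding the single query being processed sits at absolute position src_len - 1, whereas on the one-go prefix pass (or with no cache at all) every position is present and the rotation offset starts at 0. Restated as a tiny standalone helper purely for illustration (the name xpos_offset is hypothetical):

def xpos_offset(src_len, incremental_state, is_first_step):
    # Mirrors the branch in MultiheadAttention.forward above.
    if incremental_state is not None and not is_first_step:
        return src_len - 1   # only the newest query position is being computed
    return 0                 # full sequence present: positions start from 0

assert xpos_offset(4, {"is_first_step": False}, False) == 3
assert xpos_offset(4, {"is_first_step": True}, True) == 0
assert xpos_offset(4, None, False) == 0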