Merge pull request #33 from XingxingZhang/lm_sampling
support lm prefix computation in one go
commit 2b101355d7
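For context, the change lets generation code push a whole language-model prefix through the decoder in a single forward pass while still filling the per-layer incremental (key/value) caches: the caller marks the first step in `incremental_state`, and the decoder then skips the usual "keep only the last token" slicing. A minimal usage sketch, assuming a `decoder` instance and `prefix_tokens`/`next_token` tensors (these names are illustrative, not part of the diff):

    # hedged sketch: only the "is_first_step" key comes from this PR
    incremental_state = {"is_first_step": True}

    # first call: the full multi-token prefix is processed in one go,
    # and each layer's key/value cache is populated along the way
    logits, extra = decoder(prefix_tokens, incremental_state=incremental_state)

    # later calls: clear the flag (done here by the caller) and feed
    # one new token per step, as in ordinary incremental decoding
    incremental_state["is_first_step"] = False
    logits, extra = decoder(next_token, incremental_state=incremental_state)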
@@ -140,6 +140,7 @@ class DecoderLayer(nn.Module):
         self_attn_padding_mask=None,
         self_attn_rel_pos=None,
         cross_attn_rel_pos=None,
+        is_first_step=False,
     ):
         residual = x
         if self.normalize_before:
@@ -153,6 +154,7 @@ class DecoderLayer(nn.Module):
             incremental_state=incremental_state,
             attn_mask=self_attn_mask,
             rel_pos=self_attn_rel_pos,
+            is_first_step=is_first_step,
         )
         x = self.dropout_module(x)
 
@@ -357,7 +359,7 @@ class Decoder(nn.Module):
                 tokens, incremental_state=incremental_state
             )
 
-        if incremental_state is not None:
+        if incremental_state is not None and not self.is_first_step(incremental_state):
             tokens = tokens[:, -1:]
             if positions is not None:
                 positions = positions[:, -1:]
@@ -377,6 +379,11 @@ class Decoder(nn.Module):
 
         return x, embed
 
+    def is_first_step(self, incremental_state):
+        if incremental_state is None:
+            return False
+        return incremental_state.get("is_first_step", False)
+
     def forward(
         self,
         prev_output_tokens,
@@ -392,6 +399,7 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
+        is_first_step = self.is_first_step(incremental_state)
 
         # relative position
         self_attn_rel_pos_bias = None
@@ -400,7 +408,7 @@ class Decoder(nn.Module):
             self_attn_rel_pos_bias = self.self_attn_relative_position(
                 batch_size=x.size(0), qlen=slen, klen=slen
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
@@ -409,7 +417,7 @@ class Decoder(nn.Module):
                 qlen=slen,
                 klen=encoder_out["encoder_out"].size(1),
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
 
         # decoder layers
@@ -421,7 +429,7 @@ class Decoder(nn.Module):
         l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
 
         for idx, layer in enumerate(self.layers):
-            if incremental_state is None:
+            if incremental_state is None or is_first_step:
                 self_attn_mask = torch.triu(
                     torch.zeros([x.size(1), x.size(1)])
                     .float()
@@ -429,6 +437,9 @@ class Decoder(nn.Module):
                     .type_as(x),
                     1,
                 )
+                if is_first_step and incremental_state is not None:
+                    if idx not in incremental_state:
+                        incremental_state[idx] = {}
             else:
                 self_attn_mask = None
                 if idx not in incremental_state:
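Taken together, the two hunks above mean that when `is_first_step` is set and an `incremental_state` is supplied, the decoder behaves like a regular non-incremental pass over the prefix: it builds a full square causal mask over all prefix positions and pre-creates the per-layer `incremental_state[idx]` dicts so the attention layers can write their caches. A small hedged illustration of the mask for a 3-token prefix, assuming the elided line between the two hunks fills the matrix with -inf as in the surrounding code:

    import torch

    # same triu construction as above, for x.size(1) == 3
    mask = torch.triu(torch.zeros([3, 3]).float().fill_(float("-inf")), 1)
    # tensor([[0., -inf, -inf],
    #         [0.,   0., -inf],
    #         [0.,   0.,   0.]])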
@@ -445,6 +456,7 @@ class Decoder(nn.Module):
                 self_attn_padding_mask=self_attn_padding_mask,
                 self_attn_rel_pos=self_attn_rel_pos_bias,
                 cross_attn_rel_pos=cross_attn_rel_pos_bias,
+                is_first_step=is_first_step,
             )
             l_aux.append(l_aux_i)
             inner_states.append(x)
@@ -71,6 +71,7 @@ class MultiheadAttention(nn.Module):
         key_padding_mask=None,
         attn_mask=None,
         rel_pos=None,
+        is_first_step=False,
     ):
         bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
@@ -112,7 +113,7 @@ class MultiheadAttention(nn.Module):
             src_len = k.size(1)
 
         if self.xpos is not None:
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 offset = src_len - 1
             else:
                 offset = 0
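The xPos branch follows the same pattern: in ordinary incremental decoding the single query token sits at absolute position `src_len - 1`, so that offset is passed to the rotation, whereas on the first-step prefix pass the queries already cover positions 0 through `src_len - 1`, so `offset = 0` reproduces the non-incremental behaviour. A sketch with illustrative numbers only:

    # illustrative only: how the offset differs between the two decoding modes
    src_len = 5                       # length of the cached keys after the prefix pass
    offset_first_step = 0             # full prefix goes through, as in training
    offset_incremental = src_len - 1  # single new query rotated to its absolute position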