diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py
index 5af981e..ed407b0 100644
--- a/torchscale/architecture/decoder.py
+++ b/torchscale/architecture/decoder.py
@@ -140,6 +140,7 @@ class DecoderLayer(nn.Module):
         self_attn_padding_mask=None,
         self_attn_rel_pos=None,
         cross_attn_rel_pos=None,
+        is_first_step=False,
     ):
         residual = x
         if self.normalize_before:
@@ -153,6 +154,7 @@ class DecoderLayer(nn.Module):
             incremental_state=incremental_state,
             attn_mask=self_attn_mask,
             rel_pos=self_attn_rel_pos,
+            is_first_step=is_first_step,
         )
         x = self.dropout_module(x)
 
@@ -357,7 +359,7 @@ class Decoder(nn.Module):
                 tokens, incremental_state=incremental_state
             )
 
-        if incremental_state is not None:
+        if incremental_state is not None and not self.is_first_step(incremental_state):
             tokens = tokens[:, -1:]
             if positions is not None:
                 positions = positions[:, -1:]
@@ -377,6 +379,11 @@ class Decoder(nn.Module):
 
         return x, embed
 
+    def is_first_step(self, incremental_state):
+        if incremental_state is None:
+            return False
+        return incremental_state.get("is_first_step", False)
+
     def forward(
         self,
         prev_output_tokens,
@@ -392,6 +399,7 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
+        is_first_step = self.is_first_step(incremental_state)
 
         # relative position
         self_attn_rel_pos_bias = None
@@ -400,7 +408,7 @@ class Decoder(nn.Module):
             self_attn_rel_pos_bias = self.self_attn_relative_position(
                 batch_size=x.size(0), qlen=slen, klen=slen
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
@@ -409,7 +417,7 @@ class Decoder(nn.Module):
                 qlen=slen,
                 klen=encoder_out["encoder_out"].size(1),
             )
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
 
         # decoder layers
@@ -421,7 +429,7 @@ class Decoder(nn.Module):
             l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
 
         for idx, layer in enumerate(self.layers):
-            if incremental_state is None:
+            if incremental_state is None or is_first_step:
                 self_attn_mask = torch.triu(
                     torch.zeros([x.size(1), x.size(1)])
                     .float()
@@ -429,6 +437,9 @@ class Decoder(nn.Module):
                     .type_as(x),
                     1,
                 )
+                if is_first_step and incremental_state is not None:
+                    if idx not in incremental_state:
+                        incremental_state[idx] = {}
             else:
                 self_attn_mask = None
                 if idx not in incremental_state:
@@ -445,6 +456,7 @@ class Decoder(nn.Module):
                 self_attn_padding_mask=self_attn_padding_mask,
                 self_attn_rel_pos=self_attn_rel_pos_bias,
                 cross_attn_rel_pos=cross_attn_rel_pos_bias,
+                is_first_step=is_first_step,
             )
             l_aux.append(l_aux_i)
             inner_states.append(x)
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index 392d0e9..33e917e 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -71,6 +71,7 @@ class MultiheadAttention(nn.Module):
         key_padding_mask=None,
         attn_mask=None,
         rel_pos=None,
+        is_first_step=False,
     ):
         bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
@@ -112,7 +113,7 @@ class MultiheadAttention(nn.Module):
             src_len = k.size(1)
 
         if self.xpos is not None:
-            if incremental_state is not None:
+            if incremental_state is not None and not is_first_step:
                 offset = src_len - 1
             else:
                 offset = 0
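
Usage sketch (not part of the patch): the new is_first_step flag lets a caller run the entire prompt through the decoder on the first incremental step, so the full causal mask and relative-position bias are applied and each layer's key/value cache is populated, before switching to ordinary one-token-at-a-time steps where forward_embedding trims the input to its last token. The driver loop below only illustrates that contract; greedy_decode, the prompt tensor shape, and the assumption that Decoder.forward returns a (logits, extra) pair are hypothetical and not established by this diff.

import torch

@torch.no_grad()
def greedy_decode(decoder, prompt, max_new_tokens=16):
    # First call: empty cache plus the flag, so the decoder processes the whole
    # prompt (full self-attention mask, full rel-pos bias) and creates the
    # per-layer cache entries (incremental_state[idx] = {}) itself.
    incremental_state = {"is_first_step": True}
    tokens = prompt  # assumed LongTensor of shape (batch, prompt_len)
    for _ in range(max_new_tokens):
        # Assumes Decoder.forward returns (output, extra); only the output is used here.
        out, _ = decoder(tokens, incremental_state=incremental_state)
        next_token = out[:, -1, :].argmax(dim=-1, keepdim=True)  # placeholder greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)
        # After the prompt pass, clear the flag so later calls are ordinary
        # single-token incremental steps that reuse the cached keys/values.
        incremental_state["is_first_step"] = False
    return tokens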