b3 incremental decoding

Wenhui Wang 2023-03-09 12:02:36 +08:00
parent 891f84f302
commit 599df73687
5 changed files with 49 additions and 11 deletions

View File

@@ -9,6 +9,7 @@ class EncoderConfig(object):
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
+        self.normalize_output = kwargs.pop("normalize_output", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
         self.dropout = kwargs.pop("dropout", 0.0)
         self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)

View File

@@ -113,7 +113,7 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x
 
-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None, incremental_state=None):
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
@@ -131,6 +131,7 @@ class EncoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             attn_mask=attn_mask,
             rel_pos=rel_pos,
+            incremental_state=incremental_state,
         )
         x = self.dropout_module(x)
@@ -214,7 +215,7 @@ class Encoder(nn.Module):
             )
         self.num_layers = len(self.layers)
 
-        if args.encoder_normalize_before:
+        if args.encoder_normalize_before and args.normalize_output:
             self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None
@@ -308,15 +309,16 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         token_embedding=None,
+        positions=None,
     ):
         if token_embedding is None:
             token_embedding = self.embed_tokens(src_tokens)
         x = embed = self.embed_scale * token_embedding
         if self.embed_positions is not None:
             if src_tokens is not None:
-                x = embed + self.embed_positions(src_tokens)
+                x = embed + self.embed_positions(src_tokens, positions=positions)
             else:
-                x = embed + self.embed_positions(x)
+                x = embed + self.embed_positions(x, positions=positions)
         if self.layernorm_embedding is not None:
             x = self.layernorm_embedding(x)
         x = self.dropout_module(x)
@@ -326,10 +328,13 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         encoder_padding_mask=None,
+        attn_mask=None,
         return_all_hiddens=False,
         token_embeddings=None,
         multiway_split_position=None,
         features_only=False,
+        incremental_state=None,
+        positions=None,
         **kwargs
     ):
         assert src_tokens is not None or token_embeddings is not None
@@ -349,7 +354,7 @@ class Encoder(nn.Module):
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
 
-        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
         encoder_states = []
@@ -363,10 +368,16 @@ class Encoder(nn.Module):
                 batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )
 
+        # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640)
         l_aux = []
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
            x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
+                x,
+                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
+                attn_mask=attn_mask,
+                rel_pos=rel_pos_bias,
+                multiway_split_position=multiway_split_position,
+                incremental_state=incremental_state[idx] if incremental_state is not None else None,
            )
            if return_all_hiddens:
                assert encoder_states is not None
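
Note on the new cache argument: the loop above indexes the cache per layer (incremental_state[idx]), so a caller is expected to hand the encoder one mutable mapping per EncoderLayer, and the padding mask is only applied when no cache is in use. A minimal sketch of how such a cache could be built and reused across decoding steps; the assumption that each layer's attention module fills its dict in place with cached key/value tensors is mine, not something stated in this diff:

# sketch, not part of the commit
def make_incremental_state(num_layers):
    # One empty, mutable dict per EncoderLayer; Encoder.forward hands
    # incremental_state[idx] to layer idx, which is assumed to cache its
    # attention state there on the first call and extend it on later calls.
    return [{} for _ in range(num_layers)]

# usage sketch with a hypothetical `encoder` instance of the Encoder above:
# cache = make_incremental_state(encoder.num_layers)
# encoder(src_tokens=prompt, incremental_state=cache, positions=prompt_positions)
# encoder(src_tokens=next_token, incremental_state=cache, positions=next_positions)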

View File

@@ -60,6 +60,12 @@ class VisionEmbedding(nn.Module):
         else:
             self.cls_token = None
 
+    def num_position_embeddings(self):
+        if self.cls_token is None:
+            return self.num_patches
+        else:
+            return self.num_patches + 1
+
     def forward(self, x, masked_position=None, **kwargs):
         B, C, H, W = x.shape
         assert (
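
The new num_position_embeddings() helper reports how many positions the vision stream occupies (patch tokens plus the optional cls_token); BEiT3 below uses it to size the visual positional table. A quick worked example with illustrative values (img_size=224 and patch_size=16 are assumptions, not taken from this diff):

# illustrative arithmetic only
num_patches = (224 // 16) ** 2          # 196 patch tokens
num_positions = num_patches + 1         # 197 when a cls_token is prepended
vision_pos_rows = num_positions + 2     # 199 rows in the visual PositionalEmbedding,
                                        # since the Fairseq-style numbering used below starts at 2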

View File

@@ -43,3 +43,13 @@ class MultiwayNetwork(nn.Module):
         # x1, x2 = x[:self.split_position], x[self.split_position:]
         y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
         return torch.cat([y1, y2], dim=self.dim)
+
+
+class MutliwayEmbedding(MultiwayNetwork):
+    def __init__(self, modules, dim=1):
+        super(MultiwayNetwork, self).__init__()
+        self.dim = dim
+        assert len(modules) == 2
+        self.A = modules[0]
+        self.B = modules[1]
+        self.split_position = -1
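
Unlike MultiwayWrapper, which deep-copies a single module into the A/B branches, MutliwayEmbedding is handed two already-built modules and reuses MultiwayNetwork.forward: the first split_position entries along dim go through A, the rest through B, and the outputs are concatenated. Calling super(MultiwayNetwork, self).__init__() deliberately skips the copying constructor. A small usage sketch with toy nn.Embedding modules; the sizes and the set_split_position import path are assumptions for illustration, and BEiT3 itself passes two PositionalEmbedding modules, as shown in the next file:

import torch
import torch.nn as nn

# assumed import path, matching the module shown above
from torchscale.component.multiway_network import MutliwayEmbedding, set_split_position

vision_pos = nn.Embedding(199, 64)    # toy sizes, purely illustrative
text_pos = nn.Embedding(1024, 64)
embed_positions = MutliwayEmbedding(modules=[vision_pos, text_pos], dim=1)

# Route the first 199 position indices through vision_pos and the rest through text_pos.
embed_positions.apply(set_split_position(199))
positions = torch.arange(199 + 32).unsqueeze(0)   # (1, 231)
out = embed_positions(positions)                  # (1, 231, 64)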

View File

@@ -10,7 +10,7 @@ from torchscale.component.embedding import (
     TextEmbedding,
     VisionEmbedding,
 )
-from torchscale.component.multiway_network import MultiwayWrapper
+from torchscale.component.multiway_network import MutliwayEmbedding
 
 
 class BEiT3(nn.Module):
@@ -29,9 +29,12 @@ class BEiT3(nn.Module):
             contain_mask_token=True,
             prepend_cls_token=True,
         )
-        embed_positions = MultiwayWrapper(
-            args,
-            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+        # being consistent with Fairseq, which starts from 2 for position embedding
+        embed_positions = MutliwayEmbedding(
+            modules=[
+                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
+                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+            ],
+            dim=1,
         )
         self.encoder = Encoder(
@@ -47,7 +50,10 @@ class BEiT3(nn.Module):
         textual_tokens=None,
         visual_tokens=None,
         text_padding_position=None,
+        attn_mask=None,
         vision_masked_position=None,
+        incremental_state=None,
+        positions=None,
     ):
         assert textual_tokens is not None or visual_tokens is not None
@@ -79,8 +85,12 @@ class BEiT3(nn.Module):
         encoder_out = self.encoder(
             src_tokens=None,
             encoder_padding_mask=encoder_padding_mask,
+            attn_mask=attn_mask,
             token_embeddings=x,
             multiway_split_position=multiway_split_position,
+            incremental_state=incremental_state,
+            positions=positions,
         )
         encoder_out["multiway_split_position"] = multiway_split_position
         return encoder_out
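
Put together, the new arguments are what step-by-step generation with this bidirectional encoder needs (the s2s-ft setup referenced in the encoder comment): feed the prompt once with a fresh per-layer cache, then feed only the newest token together with its absolute position. A rough greedy, text-only sketch under several assumptions that are not part of this commit: lm_head is a hypothetical hidden-to-vocabulary projection, the encoder output is read from an "encoder_out" key in a batch-first layout, positions are offset by 2 to match the Fairseq-style numbering noted above, and any prefix/causal attn_mask is omitted for brevity:

import torch

@torch.no_grad()
def greedy_generate(model, lm_head, prompt_tokens, max_new_tokens=32):
    # prompt_tokens: (batch, prompt_len) LongTensor of text token ids
    cache = [{} for _ in range(model.encoder.num_layers)]  # one dict per encoder layer
    tokens = prompt_tokens
    chunk = prompt_tokens            # the first pass consumes the whole prompt
    offset = 0
    for _ in range(max_new_tokens):
        positions = torch.arange(
            offset, offset + chunk.size(1), device=chunk.device
        ).unsqueeze(0) + 2           # assumed Fairseq-style numbering: real positions start at 2
        out = model(
            textual_tokens=chunk,
            incremental_state=cache,
            positions=positions,
        )
        hidden = out["encoder_out"][:, -1]                 # assumed output key; last position's state
        next_token = lm_head(hidden).argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        offset += chunk.size(1)
        chunk = next_token           # later steps feed only the newly generated token
    return tokens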