b3 incremental decoding

Wenhui Wang 2023-03-09 12:02:36 +08:00
parent 891f84f302
commit 599df73687
5 changed files with 49 additions and 11 deletions

View File

@@ -9,6 +9,7 @@ class EncoderConfig(object):
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
+        self.normalize_output = kwargs.pop("normalize_output", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
         self.dropout = kwargs.pop("dropout", 0.0)
         self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)

View File

@@ -113,7 +113,7 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x

-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None, incremental_state=None):
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
@@ -131,6 +131,7 @@ class EncoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             attn_mask=attn_mask,
             rel_pos=rel_pos,
+            incremental_state=incremental_state,
         )

         x = self.dropout_module(x)
@@ -214,7 +215,7 @@ class Encoder(nn.Module):
             )
         self.num_layers = len(self.layers)

-        if args.encoder_normalize_before:
+        if args.encoder_normalize_before and args.normalize_output:
             self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None
@@ -308,15 +309,16 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         token_embedding=None,
+        positions=None,
     ):
         if token_embedding is None:
             token_embedding = self.embed_tokens(src_tokens)
         x = embed = self.embed_scale * token_embedding
         if self.embed_positions is not None:
             if src_tokens is not None:
-                x = embed + self.embed_positions(src_tokens)
+                x = embed + self.embed_positions(src_tokens, positions=positions)
             else:
-                x = embed + self.embed_positions(x)
+                x = embed + self.embed_positions(x, positions=positions)
         if self.layernorm_embedding is not None:
             x = self.layernorm_embedding(x)
         x = self.dropout_module(x)
@@ -326,10 +328,13 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         encoder_padding_mask=None,
+        attn_mask=None,
         return_all_hiddens=False,
         token_embeddings=None,
         multiway_split_position=None,
         features_only=False,
+        incremental_state=None,
+        positions=None,
         **kwargs
     ):
         assert src_tokens is not None or token_embeddings is not None
@@ -349,7 +354,7 @@ class Encoder(nn.Module):
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))

-        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

         encoder_states = []
@@ -363,10 +368,16 @@ class Encoder(nn.Module):
                 batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )

+        # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640)
         l_aux = []
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
             x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
+                x,
+                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
+                attn_mask=attn_mask,
+                rel_pos=rel_pos_bias,
+                multiway_split_position=multiway_split_position,
+                incremental_state=incremental_state[idx] if incremental_state is not None else None,
             )
             if return_all_hiddens:
                 assert encoder_states is not None

View File

@@ -60,6 +60,12 @@ class VisionEmbedding(nn.Module):
         else:
             self.cls_token = None

+    def num_position_embeddings(self):
+        if self.cls_token is None:
+            return self.num_patches
+        else:
+            return self.num_patches + 1
+
     def forward(self, x, masked_position=None, **kwargs):
         B, C, H, W = x.shape
         assert (

View File

@@ -43,3 +43,13 @@ class MultiwayNetwork(nn.Module):
         # x1, x2 = x[:self.split_position], x[self.split_position:]
         y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
         return torch.cat([y1, y2], dim=self.dim)
+
+
+class MutliwayEmbedding(MultiwayNetwork):
+    def __init__(self, modules, dim=1):
+        super(MultiwayNetwork, self).__init__()
+        self.dim = dim
+        assert len(modules) == 2
+        self.A = modules[0]
+        self.B = modules[1]
+        self.split_position = -1
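
Note: the new MutliwayEmbedding reuses MultiwayNetwork's forward pass, so positions up to split_position are embedded by modules[0] (the vision positional embedding) and the remaining positions by modules[1] (the text positional embedding), with the two results concatenated along dim. The snippet below is a minimal, self-contained toy that mirrors this splitting behaviour for illustration only; it is not the torchscale class, and the module names and sizes are made up.

import torch
import torch.nn as nn

class ToySplitEmbedding(nn.Module):
    # Toy mirror of the MutliwayEmbedding idea: positions before `split_position`
    # go through embedding A (vision), the rest through embedding B (text),
    # and the two outputs are concatenated along `dim`.
    def __init__(self, a, b, dim=1):
        super().__init__()
        self.A, self.B, self.dim = a, b, dim
        self.split_position = -1  # -1 means "no split": everything goes through A

    def forward(self, positions):
        if self.split_position == -1:
            return self.A(positions)
        p1 = positions[:, : self.split_position]
        p2 = positions[:, self.split_position :]
        return torch.cat([self.A(p1), self.B(p2)], dim=self.dim)

# Example: 5 vision positions followed by 4 text positions, embedding dim 8.
emb = ToySplitEmbedding(nn.Embedding(16, 8), nn.Embedding(16, 8))
emb.split_position = 5
pos = torch.arange(9).unsqueeze(0)  # shape (1, 9)
print(emb(pos).shape)               # torch.Size([1, 9, 8])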

View File

@@ -10,7 +10,7 @@ from torchscale.component.embedding import (
     TextEmbedding,
     VisionEmbedding,
 )
-from torchscale.component.multiway_network import MultiwayWrapper
+from torchscale.component.multiway_network import MutliwayEmbedding


 class BEiT3(nn.Module):
@@ -29,9 +29,12 @@ class BEiT3(nn.Module):
             contain_mask_token=True,
             prepend_cls_token=True,
         )
-        embed_positions = MultiwayWrapper(
-            args,
-            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+        # being consistent with Fairseq, which starts from 2 for position embedding
+        embed_positions = MutliwayEmbedding(
+            modules=[
+                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
+                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+            ],
             dim=1,
         )
         self.encoder = Encoder(
@@ -47,7 +50,10 @@ class BEiT3(nn.Module):
         textual_tokens=None,
         visual_tokens=None,
         text_padding_position=None,
+        attn_mask=None,
         vision_masked_position=None,
+        incremental_state=None,
+        positions=None,
     ):
         assert textual_tokens is not None or visual_tokens is not None
@@ -79,8 +85,12 @@ class BEiT3(nn.Module):
         encoder_out = self.encoder(
             src_tokens=None,
             encoder_padding_mask=encoder_padding_mask,
+            attn_mask=attn_mask,
             token_embeddings=x,
             multiway_split_position=multiway_split_position,
+            incremental_state=incremental_state,
+            positions=positions,
         )
+        encoder_out["multiway_split_position"] = multiway_split_position

         return encoder_out
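
Note: taken together, the new incremental_state, positions, and attn_mask arguments let the bidirectional encoder be driven token by token at inference time, in the spirit of s2s-ft. The sketch below is illustrative only: it assumes an already constructed BEiT3 model and an external lm_head projection (neither is part of this commit); the "encoder_out" key, the position offset of 2, and the omission of an explicit attention mask are assumptions drawn from the comments in this diff rather than a definitive recipe.

import torch

def incremental_text_decode(model, lm_head, prompt_tokens, max_new_tokens=16):
    # Hedged sketch of step-by-step decoding with the new API (not part of the commit).
    # One cache dict per encoder layer, matching incremental_state[idx] in Encoder.forward.
    incremental_state = [{} for _ in range(model.encoder.num_layers)]
    tokens = prompt_tokens                      # (batch, prompt_len)
    start = 0
    for _ in range(max_new_tokens):
        cur = tokens[:, start:]                 # only the positions not yet cached
        # Fairseq-style absolute positions, offset by 2 as noted in the diff comment (assumption).
        positions = (
            torch.arange(start, tokens.size(1), device=tokens.device)
            .unsqueeze(0)
            .expand(tokens.size(0), -1)
            + 2
        )
        # A real generator setup (e.g. s2s-ft) would also pass attn_mask; omitted here for brevity.
        out = model(
            textual_tokens=cur,
            incremental_state=incremental_state,
            positions=positions,
        )
        hidden = out["encoder_out"]             # assumed key for the final hidden states, (batch, len, dim)
        next_token = lm_head(hidden[:, -1, :]).argmax(dim=-1, keepdim=True)
        start = tokens.size(1)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens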