diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py
index 6aa7e16..0d2e9be 100644
--- a/torchscale/architecture/config.py
+++ b/torchscale/architecture/config.py
@@ -9,6 +9,7 @@ class EncoderConfig(object):
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
+        self.normalize_output = kwargs.pop("normalize_output", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
         self.dropout = kwargs.pop("dropout", 0.0)
         self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index 878b69b..62ab174 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -113,7 +113,7 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x
 
-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None, incremental_state=None):
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
@@ -131,6 +131,7 @@ class EncoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             attn_mask=attn_mask,
             rel_pos=rel_pos,
+            incremental_state=incremental_state,
         )
         x = self.dropout_module(x)
 
@@ -214,7 +215,7 @@ class Encoder(nn.Module):
             )
         self.num_layers = len(self.layers)
 
-        if args.encoder_normalize_before:
+        if args.encoder_normalize_before and args.normalize_output:
             self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None
@@ -308,15 +309,16 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         token_embedding=None,
+        positions=None,
     ):
         if token_embedding is None:
             token_embedding = self.embed_tokens(src_tokens)
         x = embed = self.embed_scale * token_embedding
         if self.embed_positions is not None:
             if src_tokens is not None:
-                x = embed + self.embed_positions(src_tokens)
+                x = embed + self.embed_positions(src_tokens, positions=positions)
             else:
-                x = embed + self.embed_positions(x)
+                x = embed + self.embed_positions(x, positions=positions)
         if self.layernorm_embedding is not None:
             x = self.layernorm_embedding(x)
         x = self.dropout_module(x)
@@ -326,10 +328,13 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         encoder_padding_mask=None,
+        attn_mask=None,
         return_all_hiddens=False,
         token_embeddings=None,
         multiway_split_position=None,
         features_only=False,
+        incremental_state=None,
+        positions=None,
         **kwargs
     ):
         assert src_tokens is not None or token_embeddings is not None
@@ -349,7 +354,7 @@ class Encoder(nn.Module):
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
 
-        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
         encoder_states = []
@@ -363,10 +368,16 @@ class Encoder(nn.Module):
                 batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )
 
+        # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640)
         l_aux = []
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
             x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
+                x,
+                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
+                attn_mask=attn_mask,
+                rel_pos=rel_pos_bias,
+                multiway_split_position=multiway_split_position,
+                incremental_state=incremental_state[idx] if incremental_state is not None else None,
             )
             if return_all_hiddens:
                 assert encoder_states is not None
diff --git a/torchscale/component/embedding.py b/torchscale/component/embedding.py
index f6cc62e..e633d5a 100644
--- a/torchscale/component/embedding.py
+++ b/torchscale/component/embedding.py
@@ -60,6 +60,12 @@ class VisionEmbedding(nn.Module):
         else:
             self.cls_token = None
 
+    def num_position_embeddings(self):
+        if self.cls_token is None:
+            return self.num_patches
+        else:
+            return self.num_patches + 1
+
     def forward(self, x, masked_position=None, **kwargs):
         B, C, H, W = x.shape
         assert (
diff --git a/torchscale/component/multiway_network.py b/torchscale/component/multiway_network.py
index d6a1ac0..a44a699 100644
--- a/torchscale/component/multiway_network.py
+++ b/torchscale/component/multiway_network.py
@@ -43,3 +43,13 @@ class MultiwayNetwork(nn.Module):
         # x1, x2 = x[:self.split_position], x[self.split_position:]
         y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
         return torch.cat([y1, y2], dim=self.dim)
+
+
+class MutliwayEmbedding(MultiwayNetwork):
+    def __init__(self, modules, dim=1):
+        super(MultiwayNetwork, self).__init__()
+        self.dim = dim
+        assert len(modules) == 2
+        self.A = modules[0]
+        self.B = modules[1]
+        self.split_position = -1
\ No newline at end of file
diff --git a/torchscale/model/BEiT3.py b/torchscale/model/BEiT3.py
index 597a063..92737a2 100644
--- a/torchscale/model/BEiT3.py
+++ b/torchscale/model/BEiT3.py
@@ -10,7 +10,7 @@ from torchscale.component.embedding import (
     TextEmbedding,
     VisionEmbedding,
 )
-from torchscale.component.multiway_network import MultiwayWrapper
+from torchscale.component.multiway_network import MutliwayEmbedding
 
 
 class BEiT3(nn.Module):
@@ -29,9 +29,12 @@ class BEiT3(nn.Module):
             contain_mask_token=True,
             prepend_cls_token=True,
         )
-        embed_positions = MultiwayWrapper(
-            args,
-            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+        # being consistent with Fairseq, which starts from 2 for position embedding
+        embed_positions = MutliwayEmbedding(
+            modules=[
+                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
+                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+            ],
             dim=1,
         )
         self.encoder = Encoder(
@@ -47,7 +50,10 @@ class BEiT3(nn.Module):
         textual_tokens=None,
         visual_tokens=None,
         text_padding_position=None,
+        attn_mask=None,
         vision_masked_position=None,
+        incremental_state=None,
+        positions=None,
     ):
         assert textual_tokens is not None or visual_tokens is not None
 
@@ -79,8 +85,12 @@ class BEiT3(nn.Module):
         encoder_out = self.encoder(
             src_tokens=None,
             encoder_padding_mask=encoder_padding_mask,
+            attn_mask=attn_mask,
             token_embeddings=x,
             multiway_split_position=multiway_split_position,
+            incremental_state=incremental_state,
+            positions=positions,
         )
+        encoder_out["multiway_split_position"] = multiway_split_position
 
         return encoder_out
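
Usage sketch (not part of the patch): a minimal, hypothetical example of calling `BEiT3.forward` with the newly added `attn_mask`, `positions`, and `incremental_state` arguments, as in generator-style (s2s-ft) use of the bidirectional encoder. The config field values, the assumption that `EncoderConfig` accepts `multiway` and `vocab_size` keywords, and the per-layer list-of-dicts layout for `incremental_state` (inferred from `incremental_state[idx]` in the diff) are assumptions, not a definitive recipe.

```python
# Hypothetical usage sketch; config values and cache layout are assumptions.
import torch

from torchscale.architecture.config import EncoderConfig
from torchscale.model.BEiT3 import BEiT3

args = EncoderConfig(multiway=True, vocab_size=64010)  # field values are illustrative
model = BEiT3(args).eval()

batch_size, seq_len = 1, 5
textual_tokens = torch.randint(0, args.vocab_size, (batch_size, seq_len))

# One cache dict per encoder layer; each layer receives incremental_state[idx].
incremental_state = [{} for _ in range(args.encoder_layers)]

# Causal mask so the bidirectional encoder can be used as a generator (s2s-ft style).
attn_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)

# Explicit absolute positions; Fairseq-style position embeddings start at index 2.
positions = torch.arange(2, seq_len + 2).unsqueeze(0)

with torch.no_grad():
    encoder_out = model(
        textual_tokens=textual_tokens,
        attn_mask=attn_mask,
        positions=positions,
        incremental_state=incremental_state,
    )

# The patch also exposes the split position in the output dict.
print(encoder_out["multiway_split_position"])
```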