b3 incremental decoding

parent 891f84f302
commit 599df73687
@@ -9,6 +9,7 @@ class EncoderConfig(object):
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
+        self.normalize_output = kwargs.pop("normalize_output", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
         self.dropout = kwargs.pop("dropout", 0.0)
         self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
@@ -113,7 +113,7 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x

-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None, incremental_state=None):
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
@@ -131,6 +131,7 @@ class EncoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             attn_mask=attn_mask,
             rel_pos=rel_pos,
+            incremental_state=incremental_state,
         )
         x = self.dropout_module(x)
@@ -214,7 +215,7 @@ class Encoder(nn.Module):
         )
         self.num_layers = len(self.layers)

-        if args.encoder_normalize_before:
+        if args.encoder_normalize_before and args.normalize_output:
             self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None
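Note: the new normalize_output flag only gates the final encoder LayerNorm; pre-normalization inside each layer is still controlled by encoder_normalize_before. A minimal sketch of the intended effect, assuming the EncoderConfig defaults shown above (illustrative, not part of the commit):

from torchscale.architecture.config import EncoderConfig

# Default: normalize_output=True, so a pre-norm encoder keeps its final LayerNorm.
cfg_default = EncoderConfig(encoder_normalize_before=True)

# Opting out: the final LayerNorm is skipped (self.layer_norm = None above),
# while the per-layer pre-norm behaviour is unchanged.
cfg_no_final_ln = EncoderConfig(encoder_normalize_before=True, normalize_output=False)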
@@ -308,15 +309,16 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         token_embedding=None,
+        positions=None,
     ):
         if token_embedding is None:
             token_embedding = self.embed_tokens(src_tokens)
         x = embed = self.embed_scale * token_embedding
         if self.embed_positions is not None:
             if src_tokens is not None:
-                x = embed + self.embed_positions(src_tokens)
+                x = embed + self.embed_positions(src_tokens, positions=positions)
             else:
-                x = embed + self.embed_positions(x)
+                x = embed + self.embed_positions(x, positions=positions)
         if self.layernorm_embedding is not None:
             x = self.layernorm_embedding(x)
         x = self.dropout_module(x)
@@ -326,10 +328,13 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         encoder_padding_mask=None,
+        attn_mask=None,
         return_all_hiddens=False,
         token_embeddings=None,
         multiway_split_position=None,
         features_only=False,
+        incremental_state=None,
+        positions=None,
         **kwargs
     ):
         assert src_tokens is not None or token_embeddings is not None
@@ -349,7 +354,7 @@ class Encoder(nn.Module):
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))

-        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

         encoder_states = []
@@ -363,10 +368,16 @@ class Encoder(nn.Module):
                 batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )

+        # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640)
         l_aux = []
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
             x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
+                x,
+                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
+                attn_mask=attn_mask,
+                rel_pos=rel_pos_bias,
+                multiway_split_position=multiway_split_position,
+                incremental_state=incremental_state[idx] if incremental_state is not None else None,
             )
             if return_all_hiddens:
                 assert encoder_states is not None
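Note: the loop above establishes a per-layer cache contract: incremental_state is indexed by layer position, and the padding mask is dropped at the attention level whenever a cache is present. A hedged sketch of that contract from the caller's side; enc, step_embedding, and step_positions are assumed names, and the internal layout of each cache dict is left to the attention module:

# Illustrative only: one cache dict per EncoderLayer, reused across steps.
incremental_state = [{} for _ in range(enc.num_layers)]

out = enc(
    src_tokens=None,
    token_embeddings=step_embedding,      # embeddings for the current step only
    incremental_state=incremental_state,  # layer idx receives incremental_state[idx]
    positions=step_positions,             # explicit positions for the current step
)
# Each layer updates its own incremental_state[idx] in place, so the same list
# is simply passed again on the next decoding step.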
@@ -60,6 +60,12 @@ class VisionEmbedding(nn.Module):
         else:
             self.cls_token = None

+    def num_position_embeddings(self):
+        if self.cls_token is None:
+            return self.num_patches
+        else:
+            return self.num_patches + 1
+
     def forward(self, x, masked_position=None, **kwargs):
         B, C, H, W = x.shape
         assert (
@@ -43,3 +43,13 @@ class MultiwayNetwork(nn.Module):
         # x1, x2 = x[:self.split_position], x[self.split_position:]
         y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
         return torch.cat([y1, y2], dim=self.dim)
+
+
+class MutliwayEmbedding(MultiwayNetwork):
+    def __init__(self, modules, dim=1):
+        super(MultiwayNetwork, self).__init__()
+        self.dim = dim
+        assert len(modules) == 2
+        self.A = modules[0]
+        self.B = modules[1]
+        self.split_position = -1
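Note: MutliwayEmbedding only overrides the constructor; it reuses the inherited MultiwayNetwork.forward, which (as the context above suggests) splits the input along dim at split_position, runs the first chunk through modules[0] and the second through modules[1], and concatenates the results. A hedged illustration with nn.Identity standing in for the two PositionalEmbedding modules used in BEiT3; set_split_position is assumed to live in the same module:

import torch
import torch.nn as nn
from torchscale.component.multiway_network import MutliwayEmbedding, set_split_position

embed = MutliwayEmbedding(modules=[nn.Identity(), nn.Identity()], dim=1)

x = torch.randn(2, 7, 4)            # e.g. 5 vision tokens followed by 2 text tokens
embed.apply(set_split_position(5))  # same mechanism the encoder uses for multiway_split_position
y = embed(x)                        # A sees x[:, :5], B sees x[:, 5:], outputs are re-concatenated
assert y.shape == x.shape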
@@ -10,7 +10,7 @@ from torchscale.component.embedding import (
     TextEmbedding,
     VisionEmbedding,
 )
-from torchscale.component.multiway_network import MultiwayWrapper
+from torchscale.component.multiway_network import MutliwayEmbedding


 class BEiT3(nn.Module):
@@ -29,9 +29,12 @@ class BEiT3(nn.Module):
             contain_mask_token=True,
             prepend_cls_token=True,
         )
-        embed_positions = MultiwayWrapper(
-            args,
-            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+        # being consistent with Fairseq, which starts from 2 for position embedding
+        embed_positions = MutliwayEmbedding(
+            modules=[
+                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
+                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+            ],
+            dim=1,
         )
         self.encoder = Encoder(
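Note: the "+ 2" follows the Fairseq convention in which a learned PositionalEmbedding reserves indices 0 and 1 (1 being the padding index), so real positions start at 2 and the vision table needs num_position_embeddings() + 2 rows. A hedged arithmetic check, assuming a standard 224x224 input with 16x16 patches and a prepended CLS token:

num_patches = (224 // 16) ** 2                 # 196 patch tokens (illustrative image/patch size)
num_position_embeddings = num_patches + 1      # +1 for the CLS token, per VisionEmbedding above
vision_pos_rows = num_position_embeddings + 2  # +2 because Fairseq-style positions start at 2
assert vision_pos_rows == 199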
@@ -47,7 +50,10 @@ class BEiT3(nn.Module):
         textual_tokens=None,
         visual_tokens=None,
         text_padding_position=None,
+        attn_mask=None,
         vision_masked_position=None,
+        incremental_state=None,
+        positions=None,
     ):
         assert textual_tokens is not None or visual_tokens is not None
@@ -79,8 +85,12 @@ class BEiT3(nn.Module):
         encoder_out = self.encoder(
             src_tokens=None,
             encoder_padding_mask=encoder_padding_mask,
+            attn_mask=attn_mask,
             token_embeddings=x,
             multiway_split_position=multiway_split_position,
+            incremental_state=incremental_state,
+            positions=positions,
         )
         encoder_out["multiway_split_position"] = multiway_split_position

         return encoder_out
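Note: putting the pieces together, a hedged sketch of driving BEiT3 as a generator in the s2s-ft style referenced above: encode the image plus the start token once, then feed one new text token per step with the per-layer caches and explicit positions. model (a constructed BEiT3), lm_head, image, bos_id, the "encoder_out" key, and the position bookkeeping are assumptions for illustration, not part of this commit:

import torch

# One cache dict per encoder layer, shared across all decoding steps.
incremental_state = [{} for _ in range(model.encoder.num_layers)]

# First pass: image + BOS fills every layer's cache; a task-specific attn_mask
# (the new argument above) could also be supplied here.
out = model(
    textual_tokens=torch.tensor([[bos_id]]),
    visual_tokens=image,                       # (B, C, H, W)
    incremental_state=incremental_state,
)
tokens = torch.tensor([[bos_id]])

# Later passes: only the newest token goes in, with an explicit Fairseq-style
# position for the text segment, while cached keys/values cover the history.
for step in range(1, 16):
    next_token = lm_head(out["encoder_out"][:, -1]).argmax(-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=1)
    out = model(
        textual_tokens=next_token,
        incremental_state=incremental_state,
        positions=torch.tensor([[step + 2]]),  # text positions start at 2 (assumed bookkeeping)
    )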