b3 incremental decoding
commit 599df73687
parent 891f84f302
@@ -9,6 +9,7 @@ class EncoderConfig(object):
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
+        self.normalize_output = kwargs.pop("normalize_output", True)
         self.activation_fn = kwargs.pop("activation_fn", "gelu")
         self.dropout = kwargs.pop("dropout", 0.0)
         self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
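A minimal usage sketch of the new flag (not part of this diff; only the config flag names come from this commit, the rest is illustrative): with encoder_normalize_before=True but normalize_output=False, the encoder keeps pre-norm inside each layer while skipping the final output LayerNorm, as gated by the check changed later in this commit.

# Hypothetical usage; only the EncoderConfig flag names come from this commit.
from torchscale.architecture.config import EncoderConfig

cfg = EncoderConfig(
    encoder_normalize_before=True,  # keep pre-norm inside each encoder layer
    normalize_output=False,         # new flag: drop the final LayerNorm over the encoder output
)
# An Encoder built from cfg sets self.layer_norm = None, because the constructor
# now checks "encoder_normalize_before and normalize_output".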
@@ -113,7 +113,7 @@ class EncoderLayer(nn.Module):
     def residual_connection(self, x, residual):
         return residual * self.alpha + x
 
-    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None):
+    def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None, incremental_state=None):
         if multiway_split_position is not None:
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
@@ -131,6 +131,7 @@ class EncoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             attn_mask=attn_mask,
             rel_pos=rel_pos,
+            incremental_state=incremental_state,
         )
         x = self.dropout_module(x)
 
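The layer itself only forwards the cache to its attention module. A rough sketch of what one layer's cache is assumed to hold in Fairseq-style incremental decoding (the prev_key/prev_value keys and shapes are an assumption about the attention implementation, not shown in this diff):

import torch

# Illustrative dimensions only.
batch_size, num_heads, prefix_len, head_dim = 1, 12, 5, 64

# Assumed contents of one layer's cache after decoding prefix_len tokens
# (Fairseq-style naming; verify against the attention module before relying on it).
layer_cache = {
    "prev_key": torch.zeros(batch_size, num_heads, prefix_len, head_dim),
    "prev_value": torch.zeros(batch_size, num_heads, prefix_len, head_dim),
}
# On the next step the attention module appends the new token's key/value to these
# tensors instead of re-encoding the whole prefix.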
@@ -214,7 +215,7 @@ class Encoder(nn.Module):
             )
         self.num_layers = len(self.layers)
 
-        if args.encoder_normalize_before:
+        if args.encoder_normalize_before and args.normalize_output:
             self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None
@@ -308,15 +309,16 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         token_embedding=None,
+        positions=None,
    ):
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
-               x = embed + self.embed_positions(src_tokens)
+               x = embed + self.embed_positions(src_tokens, positions=positions)
            else:
-               x = embed + self.embed_positions(x)
+               x = embed + self.embed_positions(x, positions=positions)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
@@ -326,10 +328,13 @@ class Encoder(nn.Module):
         self,
         src_tokens,
         encoder_padding_mask=None,
+        attn_mask=None,
         return_all_hiddens=False,
         token_embeddings=None,
         multiway_split_position=None,
         features_only=False,
+        incremental_state=None,
+        positions=None,
         **kwargs
     ):
         assert src_tokens is not None or token_embeddings is not None
@@ -349,7 +354,7 @@ class Encoder(nn.Module):
             assert self.args.multiway
             self.apply(set_split_position(multiway_split_position))
 
-        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
         x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
 
         encoder_states = []
@@ -363,10 +368,16 @@ class Encoder(nn.Module):
                 batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
             )
 
+        # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640)
         l_aux = []
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
             x, l_aux_i = layer(
-                x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias, multiway_split_position=multiway_split_position
+                x,
+                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
+                attn_mask=attn_mask,
+                rel_pos=rel_pos_bias,
+                multiway_split_position=multiway_split_position,
+                incremental_state=incremental_state[idx] if incremental_state is not None else None,
             )
             if return_all_hiddens:
                 assert encoder_states is not None
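incremental_state is indexed per layer, so the caller is expected to pass one cache per encoder layer. A minimal decoding-loop sketch under that assumption (the model handle, the tied LM head, and the "+ 2" position offset are illustrative, not part of this diff):

import torch

# "model" is assumed to be a BEiT3 instance built from this branch.
incremental_state = [{} for _ in range(model.encoder.num_layers)]  # one cache per layer

generated = torch.tensor([[0, 1234, 5678]])  # (batch, prompt_len); hypothetical token ids
for step in range(8):
    if step == 0:
        tokens, positions = generated, None   # first call: encode the whole prompt
    else:
        tokens = generated[:, -1:]            # later calls: feed only the newest token
        # keep counting positions past the prompt; the "+ 2" mirrors the Fairseq-style
        # offset mentioned elsewhere in this commit (an assumption here)
        positions = torch.tensor([[generated.size(1) - 1 + 2]])
    out = model(
        textual_tokens=tokens,
        incremental_state=incremental_state,
        positions=positions,
    )
    # Hypothetical LM head: tie the output projection to the text embedding weights.
    logits = out["encoder_out"][:, -1, :] @ model.text_embed.weight.t()
    next_token = logits.argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=1)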
@@ -60,6 +60,12 @@ class VisionEmbedding(nn.Module):
         else:
             self.cls_token = None
 
+    def num_position_embeddings(self):
+        if self.cls_token is None:
+            return self.num_patches
+        else:
+            return self.num_patches + 1
+
     def forward(self, x, masked_position=None, **kwargs):
         B, C, H, W = x.shape
         assert (
@@ -43,3 +43,13 @@ class MultiwayNetwork(nn.Module):
         # x1, x2 = x[:self.split_position], x[self.split_position:]
         y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
         return torch.cat([y1, y2], dim=self.dim)
+
+
+class MutliwayEmbedding(MultiwayNetwork):
+    def __init__(self, modules, dim=1):
+        super(MultiwayNetwork, self).__init__()
+        self.dim = dim
+        assert len(modules) == 2
+        self.A = modules[0]
+        self.B = modules[1]
+        self.split_position = -1
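MutliwayEmbedding reuses MultiwayNetwork's splitting forward: everything before split_position along dim is routed to module A, the rest to module B, and the two outputs are concatenated. A toy sketch with plain nn.Embedding tables standing in for PositionalEmbedding (the sizes and the manual split_position assignment are illustrative):

import torch
import torch.nn as nn

from torchscale.component.multiway_network import MutliwayEmbedding

# Toy position tables standing in for PositionalEmbedding (sizes are illustrative).
vision_pos = nn.Embedding(199, 768)   # e.g. 197 vision positions + 2 reserved rows
text_pos = nn.Embedding(1024, 768)

embed_positions = MutliwayEmbedding(modules=[vision_pos, text_pos], dim=1)
embed_positions.split_position = 197  # indices before 197 go to vision_pos, the rest to text_pos

positions = torch.arange(197 + 5).unsqueeze(0)  # (1, 202): 197 vision slots + 5 text slots
out = embed_positions(positions)                # (1, 202, 768): the two outputs concatenated on dim=1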
@@ -10,7 +10,7 @@ from torchscale.component.embedding import (
     TextEmbedding,
     VisionEmbedding,
 )
-from torchscale.component.multiway_network import MultiwayWrapper
+from torchscale.component.multiway_network import MutliwayEmbedding
 
 
 class BEiT3(nn.Module):
@@ -29,9 +29,12 @@ class BEiT3(nn.Module):
             contain_mask_token=True,
             prepend_cls_token=True,
         )
-        embed_positions = MultiwayWrapper(
-            args,
-            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+        # being consistent with Fairseq, which starts from 2 for position embedding
+        embed_positions = MutliwayEmbedding(
+            modules=[
+                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
+                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+            ],
             dim=1,
         )
         self.encoder = Encoder(
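The "+ 2" follows the Fairseq convention named in the comment: the first two rows of the positional table are reserved, so real positions start at index 2. With example vision settings (224-pixel images, 16-pixel patches, a prepended cls token) the table sizes work out as below; the concrete numbers are illustrative, not taken from this diff:

# Example sizing for the vision positional table (illustrative values only).
num_patches = (224 // 16) ** 2                  # 196 patch tokens for a 224x224 image
num_position_embeddings = num_patches + 1       # 197 once the cls token is prepended
vision_pos_rows = num_position_embeddings + 2   # 199 rows: indices 0 and 1 reserved,
                                                # real positions start at 2 (Fairseq style)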
@@ -47,7 +50,10 @@ class BEiT3(nn.Module):
         textual_tokens=None,
         visual_tokens=None,
         text_padding_position=None,
+        attn_mask=None,
         vision_masked_position=None,
+        incremental_state=None,
+        positions=None,
     ):
         assert textual_tokens is not None or visual_tokens is not None
 
|
@ -79,8 +85,12 @@ class BEiT3(nn.Module):
|
||||||
encoder_out = self.encoder(
|
encoder_out = self.encoder(
|
||||||
src_tokens=None,
|
src_tokens=None,
|
||||||
encoder_padding_mask=encoder_padding_mask,
|
encoder_padding_mask=encoder_padding_mask,
|
||||||
|
attn_mask=attn_mask,
|
||||||
token_embeddings=x,
|
token_embeddings=x,
|
||||||
multiway_split_position=multiway_split_position,
|
multiway_split_position=multiway_split_position,
|
||||||
|
incremental_state=incremental_state,
|
||||||
|
positions=positions,
|
||||||
)
|
)
|
||||||
|
encoder_out["multiway_split_position"] = multiway_split_position
|
||||||
|
|
||||||
return encoder_out
|
return encoder_out
|
||||||
|
|
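Storing the split position in the output dict lets downstream code recover which slice of the concatenated sequence is vision and which is text. A short sketch of that use, assuming outputs is the dict returned by BEiT3.forward on a vision+text batch and that the other keys follow torchscale's usual encoder output format:

# "outputs" is assumed to be the dict returned by BEiT3.forward on a vision+text batch.
split = outputs["multiway_split_position"]
hidden = outputs["encoder_out"]           # (batch, vision_len + text_len, dim)

if split is not None and split > 0:
    vision_states = hidden[:, :split, :]  # image patch (and cls) features
    text_states = hidden[:, split:, :]    # text token features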