# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
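"""BEiT-3 backbone: a Multiway Transformer encoder that handles images, text,
or image-text pairs with modality-specific expert modules behind a shared
self-attention. This file defines the backbone only; task heads live elsewhere."""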
import torch
import torch.nn as nn

from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import (
    PositionalEmbedding,
    TextEmbedding,
    VisionEmbedding,
)
from torchscale.component.multiway_network import MultiwayWrapper


class BEiT3(nn.Module):
    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        # BEiT-3 requires the Multiway (modality-expert) encoder and its own
        # text embedding table; a shared input/output embedding is unsupported.
        assert args.multiway
        assert args.vocab_size > 0
        assert not args.share_encoder_input_output_embed
        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
        # Patch-embed images; reserve a learnable mask token (for masked image
        # modeling) and prepend a [CLS] token to the patch sequence.
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True,
        )
        # One positional embedding per modality; the Multiway wrapper splits
        # the input along the sequence dimension (dim=1).
        embed_positions = MultiwayWrapper(
            args,
            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )
    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        vision_masked_position=None,
    ):
        assert textual_tokens is not None or visual_tokens is not None

        if textual_tokens is None:
            # Vision-only: no padding mask; split position -1 routes every
            # token through the vision experts.
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            # Text-only: split position 0 routes every token through the
            # language experts.
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            # Vision-language: concatenate [image tokens; text tokens] and
            # split the multiway routing at the image/text boundary.
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)
            x = torch.cat([x1, x2], dim=1)

            if text_padding_position is not None:
                # Image tokens are never padding, so prepend an all-False
                # mask for them ahead of the text padding mask.
                encoder_padding_mask = torch.cat(
                    [
                        torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
                        text_padding_position,
                    ],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
        )

        return encoder_out
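

# Usage sketch (illustrative, not part of the original file): builds a small
# BEiT3 and runs one vision-language forward pass. Assumes torchscale's
# EncoderConfig accepts the fields below; all concrete values are hypothetical.
if __name__ == "__main__":
    from torchscale.architecture.config import EncoderConfig

    args = EncoderConfig(
        multiway=True,
        vocab_size=64010,  # hypothetical vocabulary size
        img_size=224,
        patch_size=16,
        in_chans=3,
        encoder_embed_dim=768,
        encoder_attention_heads=12,
        encoder_ffn_embed_dim=3072,
        encoder_layers=12,
        max_source_positions=1024,
        share_encoder_input_output_embed=False,
    )
    model = BEiT3(args)

    images = torch.randn(2, 3, 224, 224)                 # (batch, C, H, W)
    tokens = torch.randint(0, args.vocab_size, (2, 16))  # (batch, seq_len)
    out = model(textual_tokens=tokens, visual_tokens=images)
    # The encoder returns a dict; "encoder_out" should hold the fused
    # representations, here (batch, 196 patches + 1 CLS + 16 text, embed_dim).
    print(out["encoder_out"].shape)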