torchscale/torchscale/model/BEiT3.py
2022-11-26 09:01:02 -08:00

87 lines
2.7 KiB
Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import (
PositionalEmbedding,
TextEmbedding,
VisionEmbedding,
)
from torchscale.component.multiway_network import MultiwayWrapper
class BEiT3(nn.Module):
    """Multiway encoder that consumes text tokens, vision tokens, or both.

    The module owns a text embedding, a vision embedding and an ``Encoder``.
    Inputs are embedded here and handed to the encoder as ready-made
    ``token_embeddings``; the encoder itself is built without a token
    embedding (``embed_tokens=None``) and without an output projection.
    """

    def __init__(self, args, **kwargs):
        """Build the embeddings and the encoder from ``args``.

        Args:
            args: configuration object. This constructor reads ``multiway``,
                ``vocab_size``, ``share_encoder_input_output_embed``,
                ``encoder_embed_dim``, ``img_size``, ``patch_size``,
                ``in_chans`` and ``max_source_positions``; ``args`` is also
                forwarded as-is to ``MultiwayWrapper`` and ``Encoder``.
            **kwargs: accepted for call-site compatibility; not used here.

        Raises:
            AssertionError: if ``args.multiway`` is falsy, ``args.vocab_size``
                is not positive, or
                ``args.share_encoder_input_output_embed`` is truthy.
        """
        super().__init__()
        self.args = args

        # Configuration preconditions for this model.
        assert args.multiway, "BEiT3 requires args.multiway to be enabled"
        assert args.vocab_size > 0, "BEiT3 requires args.vocab_size > 0"
        assert (
            not args.share_encoder_input_output_embed
        ), "BEiT3 does not support share_encoder_input_output_embed"

        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True,
        )

        # The positional embedding is wrapped so that it is multiway-aware;
        # dim=1 is the axis handed to MultiwayWrapper (presumably the
        # sequence axis along which the modalities are split -- confirm
        # against MultiwayWrapper).
        embed_positions = MultiwayWrapper(
            args,
            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        vision_masked_position=None,
    ):
        """Embed the given modalities and run them through the encoder.

        At least one of ``textual_tokens`` / ``visual_tokens`` must be given.

        * Vision only: ``multiway_split_position`` is ``-1`` and no padding
          mask is passed.
        * Text only: ``multiway_split_position`` is ``0`` and
          ``text_padding_position`` is passed through unchanged as the
          padding mask.
        * Both: vision embeddings are placed first, followed by the text
          embeddings (concatenated on dim 1); ``multiway_split_position`` is
          the vision sequence length. If ``text_padding_position`` is given,
          the padding mask is an all-False block for the vision positions
          followed by ``text_padding_position``; otherwise no mask is passed.

        Args:
            textual_tokens: input for ``self.text_embed``, or ``None``.
            visual_tokens: input for ``self.vision_embed``, or ``None``.
            text_padding_position: padding mask for the text positions, or
                ``None``. Assumed to share its leading dimension with the
                text embeddings and to live on the same device -- TODO
                confirm against callers.
            vision_masked_position: forwarded as the second argument of
                ``self.vision_embed``.

        Returns:
            Whatever ``self.encoder`` returns.

        Raises:
            AssertionError: if both ``textual_tokens`` and ``visual_tokens``
                are ``None``.
        """
        assert (
            textual_tokens is not None or visual_tokens is not None
        ), "at least one of textual_tokens / visual_tokens must be provided"

        if textual_tokens is None:
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)
            x = torch.cat([x1, x2], dim=1)

            if text_padding_position is not None:
                # Vision positions are never padding. Allocate the all-False
                # block directly as bool on x1's device instead of building a
                # float tensor on the CPU and then moving and casting it.
                vision_padding = torch.zeros(
                    x1.shape[:-1], dtype=torch.bool, device=x1.device
                )
                encoder_padding_mask = torch.cat(
                    [vision_padding, text_padding_position],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
        )
        return encoder_out