2022-11-23 16:36:55 +00:00
|
|
|
# Copyright (c) 2022 Microsoft
|
|
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
|
2022-11-23 16:21:58 +00:00
|
|
|
import pytest
|
2022-11-26 17:01:02 +00:00
|
|
|
import torch
|
|
|
|
|
2022-11-23 16:21:58 +00:00
|
|
|
from torchscale.architecture.config import EncoderDecoderConfig
|
|
|
|
from torchscale.architecture.encoder_decoder import EncoderDecoder
|
2022-11-26 17:01:02 +00:00
|
|
|
from torchscale.component.embedding import PositionalEmbedding, TextEmbedding
|
2022-11-23 16:21:58 +00:00
|
|
|
|
|
|
|
# Flags that switch both stacks to post-LayerNorm; shared by two cases below.
_POST_LN = {"encoder_normalize_before": False, "decoder_normalize_before": False}

# Keyword-argument overrides for EncoderDecoderConfig, one dict per
# parametrized test run. Every entry toggles a single feature (or a small
# set of flags that are specified together) on top of the config defaults.
testcases = [
    # Defaults only.
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    # Fresh copy so the entry does not alias the shared constant.
    {**_POST_LN},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    # deepnorm is specified together with subln off and post-LN.
    {"deepnorm": True, "subln": False, **_POST_LN},
    {"bert_init": True},
    {"multiway": True},
    {"share_decoder_input_output_embed": True},
    {"share_all_embeddings": True},
    {"checkpoint_activations": True},
    {"fsdp": True},
]
|
|
|
|
|
2022-11-26 16:10:15 +00:00
|
|
|
|
2022-11-23 16:21:58 +00:00
|
|
|
@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
    """Smoke-test an EncoderDecoder built from the config overrides in ``args``.

    Despite the name, this exercises the full ``EncoderDecoder`` model: it
    builds the config, attaches token and positional embeddings for both
    sides, and runs a single forward pass. Nothing is asserted about the
    output; the test passes as long as construction and the forward call
    do not raise.

    Args:
        args: Keyword arguments forwarded to ``EncoderDecoderConfig``.
    """
    cfg = EncoderDecoderConfig(**args)

    # Embeddings are created in the same order as the keyword arguments
    # below (tokens first, then positions; encoder before decoder).
    enc_tokens = TextEmbedding(64000, cfg.encoder_embed_dim)
    dec_tokens = TextEmbedding(64000, cfg.decoder_embed_dim)
    enc_positions = PositionalEmbedding(
        cfg.max_source_positions, cfg.encoder_embed_dim
    )
    dec_positions = PositionalEmbedding(
        cfg.max_target_positions, cfg.decoder_embed_dim
    )

    model = EncoderDecoder(
        cfg,
        encoder_embed_tokens=enc_tokens,
        decoder_embed_tokens=dec_tokens,
        encoder_embed_positions=enc_positions,
        decoder_embed_positions=dec_positions,
    )

    # Dummy int64 batch of all-ones token ids: 2 x 20 source, 2 x 10 target.
    batch = {
        "src_tokens": torch.ones(2, 20, dtype=torch.long),
        "prev_output_tokens": torch.ones(2, 10, dtype=torch.long),
    }

    model(**batch, features_only=True)
|