torchscale/tests/test_encoder.py
2022-11-23 08:36:55 -08:00

32 lines
989 B
Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import pytest
from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder
import torch
# Parametrization table for the encoder test in this module: every entry is a
# dict of keyword overrides that is unpacked into EncoderConfig(**args).
# The leading empty dict passes no overrides at all.
testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"encoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    # Three flags are overridden together in this case.
    {"deepnorm": True, "subln": False, "encoder_normalize_before": False},
    {"bert_init": True},
    {"multiway": True},
    {"share_encoder_input_output_embed": True},
    {"checkpoint_activations": True},
    {"fsdp": True},
]
@pytest.mark.parametrize("args", testcases)
def test_encoder(args):
    """Smoke-test one Encoder forward pass for a set of config overrides.

    Builds an ``EncoderConfig`` from ``args``, constructs an ``Encoder`` from
    it, and calls the model once on random embeddings. The return value is
    not inspected: the test passes as long as construction and the forward
    call complete without raising.

    Args:
        args: Dict of keyword overrides unpacked into ``EncoderConfig``.
    """
    config = EncoderConfig(**args)
    model = Encoder(config)
    # Last dimension is taken from the config so the input always matches the
    # model width; the leading (2, 10) is presumably (batch, sequence length)
    # -- confirm against Encoder.forward.
    token_embeddings = torch.rand(2, 10, config.encoder_embed_dim)
    # No token ids are supplied; the embeddings are passed in directly.
    model(src_tokens=None, token_embeddings=token_embeddings)