# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import pytest
import torch

from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder
from torchscale.component.embedding import PositionalEmbedding, TextEmbedding

# Each dict overrides a subset of the EncoderDecoderConfig defaults for one
# parametrized run; the empty dict exercises the default configuration.
testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"encoder_normalize_before": False, "decoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    {
        "deepnorm": True,
        "subln": False,
        "encoder_normalize_before": False,
        "decoder_normalize_before": False,
    },
    {"bert_init": True},
    {"multiway": True},
    {"share_decoder_input_output_embed": True},
    {"share_all_embeddings": True},
    {"checkpoint_activations": True},
    {"fsdp": True},
]


@pytest.mark.parametrize("args", testcases)
def test_encoder_decoder(args):
    config = EncoderDecoderConfig(**args)
    model = EncoderDecoder(
        config,
        encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
        decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
        encoder_embed_positions=PositionalEmbedding(
            config.max_source_positions, config.encoder_embed_dim
        ),
        decoder_embed_positions=PositionalEmbedding(
            config.max_target_positions, config.decoder_embed_dim
        ),
    )

    # Dummy batch: 2 source sequences of length 20 and 2 target prefixes of
    # length 10; features_only=True requests decoder features rather than
    # vocabulary logits.
    src_tokens = torch.ones(2, 20).long()
    prev_output_tokens = torch.ones(2, 10).long()

    model(
        src_tokens=src_tokens,
        prev_output_tokens=prev_output_tokens,
        features_only=True,
    )
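

# A minimal standalone sketch (an addition, not part of the upstream test suite)
# showing the same construction and forward pass outside pytest, so the file can
# also be executed directly. It reuses only the calls exercised in the test
# above; the structure of the returned value is not inspected here, since it
# depends on the torchscale version.
if __name__ == "__main__":
    config = EncoderDecoderConfig()
    model = EncoderDecoder(
        config,
        encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
        decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
        encoder_embed_positions=PositionalEmbedding(
            config.max_source_positions, config.encoder_embed_dim
        ),
        decoder_embed_positions=PositionalEmbedding(
            config.max_target_positions, config.decoder_embed_dim
        ),
    )
    # Same dummy batch as in the parametrized test.
    out = model(
        src_tokens=torch.ones(2, 20).long(),
        prev_output_tokens=torch.ones(2, 10).long(),
        features_only=True,
    )
    print("forward pass completed; output type:", type(out))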