Update MT config

This commit is contained in:
Shuming Ma 2023-07-31 09:17:03 -07:00
parent ea07735c7b
commit 5356b252c4

View File

@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding
from fairseq.modules import PositionalEmbedding
from torch import Tensor
from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
from torchscale.architecture.encoder import Encoder
from .language_modeling import LMDecoder as MTDecoder
@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
@classmethod
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
config = EncoderConfig()
config = EncoderDecoderConfig()
config.override(args)
return MTEncoder(
@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
def build_decoder(
cls, args, embed_tokens, embed_positions, output_projection, dictionary
):
config = DecoderConfig()
config = EncoderDecoderConfig()
config.override(args)
return MTDecoder(