diff --git a/examples/fairseq/models/machine_translation.py b/examples/fairseq/models/machine_translation.py index 9063da3..2aa8832 100644 --- a/examples/fairseq/models/machine_translation.py +++ b/examples/fairseq/models/machine_translation.py @@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding from fairseq.modules import PositionalEmbedding from torch import Tensor -from torchscale.architecture.config import DecoderConfig, EncoderConfig +from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig from torchscale.architecture.encoder import Encoder from .language_modeling import LMDecoder as MTDecoder @@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel): @classmethod def build_encoder(cls, args, embed_tokens, embed_positions, dictionary): - config = EncoderConfig() + config = EncoderDecoderConfig() config.override(args) return MTEncoder( @@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel): def build_decoder( cls, args, embed_tokens, embed_positions, output_projection, dictionary ): - config = DecoderConfig() + config = EncoderDecoderConfig() config.override(args) return MTDecoder(